diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/COO.h
@@ -80,12 +80,30 @@
 /// an intermediate representation; e.g., for reading sparse tensors
 /// from external formats into memory, or for certain conversions between
 /// different `SparseTensorStorage` formats.
+///
+/// This class provides all the typedefs required by the "Container"
+/// concept (<https://en.cppreference.com/w/cpp/named_req/Container>);
+/// however, beware that it cannot fully implement that concept since
+/// it cannot have a default ctor (because the `dimSizes` field is const).
+/// Thus these typedefs are provided for familiarity reasons, rather
+/// than as a proper implementation of the concept.
 template <typename V>
 class SparseTensorCOO final {
 public:
+  using value_type = const Element<V>;
+  using reference = value_type &;
+  using const_reference = reference;
+  // The types associated with `std::vector` differ significantly between
+  // C++11/17 vs C++20; so we explicitly defer to whatever `std::vector`
+  // says the types should be.
+  using vector_type = std::vector<Element<V>>;
+  using iterator = typename vector_type::const_iterator;
+  using const_iterator = iterator;
+  using difference_type = typename vector_type::difference_type;
+  using size_type = typename vector_type::size_type;
+
   SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
-      : dimSizes(dimSizes), isSorted(true), iteratorLocked(false),
-        iteratorPos(0) {
+      : dimSizes(dimSizes), isSorted(true) {
     if (capacity) {
       elements.reserve(capacity);
       indices.reserve(capacity * getRank());
@@ -129,12 +147,12 @@
   /// Resolving such conflicts is left up to clients of the iterator
   /// interface.
   ///
+  /// This method invalidates all iterators.
+  ///
   /// Asserts:
-  /// * is not in iterator mode
   /// * the `ind` is valid for `rank`
   /// * the elements of `ind` are valid for `dimSizes`.
   void add(const std::vector<uint64_t> &ind, V val) {
-    assert(!iteratorLocked && "Attempt to add() after startIterator()");
     const uint64_t *base = indices.data();
     uint64_t size = indices.size();
     uint64_t rank = getRank();
@@ -161,44 +179,25 @@
     elements.push_back(addedElem);
   }

+  const_iterator begin() const { return elements.cbegin(); }
+  const_iterator end() const { return elements.cend(); }
+
   /// Sorts elements lexicographically by index. If an index is mapped to
   /// multiple values, then the relative order of those values is unspecified.
   ///
-  /// Asserts: is not in iterator mode.
+  /// This method invalidates all iterators.
   void sort() {
-    assert(!iteratorLocked && "Attempt to sort() after startIterator()");
     if (isSorted)
       return;
     std::sort(elements.begin(), elements.end(), getElementLT());
     isSorted = true;
   }

-  /// Switches into iterator mode. If already in iterator mode, then
-  /// resets the position to the first element.
-  void startIterator() {
-    iteratorLocked = true;
-    iteratorPos = 0;
-  }
-
-  /// Gets the next element. If there are no remaining elements, then
-  /// returns nullptr and switches out of iterator mode.
-  ///
-  /// Asserts: is in iterator mode.
-  const Element<V> *getNext() {
-    assert(iteratorLocked && "Attempt to getNext() before startIterator()");
-    if (iteratorPos < elements.size())
-      return &(elements[iteratorPos++]);
-    iteratorLocked = false;
-    return nullptr;
-  }
-
 private:
   const std::vector<uint64_t> dimSizes; // per-dimension sizes
   std::vector<Element<V>> elements;     // all COO elements
   std::vector<uint64_t> indices;        // shared index pool
   bool isSorted;
-  bool iteratorLocked;
-  unsigned iteratorPos;
 };

 } // namespace sparse_tensor
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
@@ -43,8 +43,8 @@
 /// This is the "swiss army knife" method for materializing sparse
 /// tensors into the computation. The types of the `ptr` argument and
 /// the result depend on the action, as explained in the following table
-/// (where "STS" means a sparse-tensor-storage object, and "COO" means
-/// a coordinate-scheme object).
+/// (where "STS" means a sparse-tensor-storage object, "COO" means
+/// a coordinate-scheme object, and "Iterator" means an iterator object).
 ///
 /// Action:         `ptr`:          Returns:
 /// kEmpty          unused          STS, empty
@@ -53,7 +53,8 @@
 /// kFromCOO        COO             STS, copied from the COO source
 /// kToCOO          STS             COO, copied from the STS source
 /// kSparseToSparse STS             STS, copied from the STS source
-/// kToIterator     STS             COO-Iterator, call @getNext to use
+/// kToIterator     STS             Iterator, call @getNext to use and
+///                                 @delSparseTensorIterator to free.
 MLIR_CRUNNERUTILS_EXPORT void *
 _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
                              StridedMemRefType<index_type, 1> *sref,
@@ -150,6 +151,12 @@
 MLIR_SPARSETENSOR_FOREVERY_V(DECL_DELCOO)
 #undef DECL_DELCOO

+/// Releases the memory for an iterator object.
+#define DECL_DELITER(VNAME, V) \
+  MLIR_CRUNNERUTILS_EXPORT void delSparseTensorIterator##VNAME(void *iter);
+MLIR_SPARSETENSOR_FOREVERY_V(DECL_DELITER)
+#undef DECL_DELITER
+
 /// Helper function to read a sparse tensor filename from the environment,
 /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
 MLIR_CRUNNERUTILS_EXPORT char *getTensorFilename(index_type id);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -319,6 +319,14 @@
   createFuncCall(builder, loc, name, {}, coo, EmitCInterface::Off);
 }

+/// Generates a call to release/delete a `SparseTensorIterator`.
+static void genDelIteratorCall(OpBuilder &builder, Location loc, Type elemTp,
+                               Value iter) {
+  SmallString<26> name{"delSparseTensorIterator",
+                       primaryTypeFunctionSuffix(elemTp)};
+  createFuncCall(builder, loc, name, {}, iter, EmitCInterface::Off);
+}
+
 /// Generates a call that adds one element to a coordinate scheme.
 /// In particular, this generates code like the following:
 ///   val = a[i1,..,ik];
@@ -335,7 +343,7 @@
 /// Generates a call to `iter->getNext()`. If there is a next element,
 /// then it is copied into the out-parameters `ind` and `elemPtr`,
 /// and the return value is true. If there isn't a next element, then
-/// the memory for `iter` is freed and the return value is false.
+/// the return value is false.
 static Value genGetNextCall(OpBuilder &builder, Location loc, Value iter,
                             Value ind, Value elemPtr) {
   Type elemTp = elemPtr.getType().cast<MemRefType>().getElementType();
@@ -572,7 +580,7 @@
   params[7] = coo;
   Value dst = genNewCall(rewriter, loc, params);
   genDelCOOCall(rewriter, loc, elemTp, coo);
-  genDelCOOCall(rewriter, loc, elemTp, iter);
+  genDelIteratorCall(rewriter, loc, elemTp, iter);
   rewriter.replaceOp(op, dst);
   return success();
 }
@@ -584,6 +592,7 @@
 // }
 // TODO: It can be used by other operators (ReshapeOp, ConvertOP) conversion to
 // reduce code repetition!
+// TODO: rename to `genSparseIterationLoop`?
 static void genSparseCOOIterationLoop(
     ConversionPatternRewriter &rewriter, Location loc, Value t,
     RankedTensorType tensorTp,
@@ -624,7 +633,7 @@
   rewriter.setInsertionPointAfter(whileOp);

   // Free memory for iterator.
-  genDelCOOCall(rewriter, loc, elemTp, iter);
+  genDelIteratorCall(rewriter, loc, elemTp, iter);
 }

 // Generate loop that iterates over a dense tensor.
@@ -875,11 +884,11 @@
   if (!encDst && encSrc) {
     // This is sparse => dense conversion, which is handled as follows:
     //   dst = new Tensor(0);
-    //   iter = src->toCOO();
-    //   iter->startIterator();
+    //   iter = new SparseTensorIterator(src);
     //   while (elem = iter->getNext()) {
     //     dst[elem.indices] = elem.value;
     //   }
+    //   delete iter;
     RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
     RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
     unsigned rank = dstTensorTp.getRank();
@@ -918,7 +927,7 @@
     insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, ivs);
     rewriter.create<scf::YieldOp>(loc);
     rewriter.setInsertionPointAfter(whileOp);
-    genDelCOOCall(rewriter, loc, elemTp, iter);
+    genDelIteratorCall(rewriter, loc, elemTp, iter);
     rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, dst);
     // Deallocate the buffer.
     if (bufferization::allocationDoesNotEscape(op->getOpResult(0))) {
diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -68,6 +68,44 @@

 namespace {

+/// Wrapper class to avoid memory leakage issues. The `SparseTensorCOO`
+/// class provides a standard C++ iterator interface, where the iterator
+/// is implemented as per `std::vector`'s iterator. However, for MLIR's
+/// usage we need to have an iterator which also holds onto the underlying
+/// `SparseTensorCOO` so that it can be freed whenever the iterator
+/// is freed.
+//
+// We name this `SparseTensorIterator` rather than `SparseTensorCOOIterator`
+// for future-proofing, since the use of `SparseTensorCOO` is an
+// implementation detail that we eventually want to change (e.g., to
+// use `SparseTensorEnumerator` directly, rather than constructing the
+// intermediate `SparseTensorCOO` at all).
+template <typename V>
+class SparseTensorIterator final {
+public:
+  /// This ctor requires `coo` to be a non-null pointer to a dynamically
+  /// allocated object, and takes ownership of that object. Therefore,
+  /// callers must not free the underlying COO object, since the iterator's
+  /// dtor will do so.
+  explicit SparseTensorIterator(const SparseTensorCOO<V> *coo)
+      : coo(coo), it(coo->begin()), end(coo->end()) {}
+
+  ~SparseTensorIterator() { delete coo; }
+
+  // Disable copy-ctor and copy-assignment, to prevent double-free.
+  SparseTensorIterator(const SparseTensorIterator &) = delete;
+  SparseTensorIterator &operator=(const SparseTensorIterator &) = delete;
+
+  /// Gets the next element. If there are no remaining elements, then
+  /// returns nullptr.
+  const Element<V> *getNext() { return it < end ? &*it++ : nullptr; }
+
+private:
+  const SparseTensorCOO<V> *const coo; // Owning pointer.
+  typename SparseTensorCOO<V>::const_iterator it;
+  const typename SparseTensorCOO<V>::const_iterator end;
+};
+
 /// Initializes sparse tensor from an external COO-flavored format.
 /// Used by `IMPL_CONVERTTOMLIRSPARSETENSOR`.
 // TODO: generalize beyond 64-bit indices.
@@ -194,7 +232,7 @@
       return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm); \
     coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm); \
     if (action == Action::kToIterator) { \
-      coo->startIterator(); \
+      return new SparseTensorIterator<V>(coo); \
     } else { \
       assert(action == Action::kToCOO); \
     } \
@@ -398,16 +436,16 @@
 #undef IMPL_ADDELT

 #define IMPL_GETNEXT(VNAME, V) \
-  bool _mlir_ciface_getNext##VNAME(void *coo, \
+  bool _mlir_ciface_getNext##VNAME(void *iter, \
                                    StridedMemRefType<index_type, 1> *iref, \
                                    StridedMemRefType<V, 0> *vref) { \
-    assert(coo &&iref &&vref); \
+    assert(iter &&iref &&vref); \
     assert(iref->strides[0] == 1); \
     index_type *indx = iref->data + iref->offset; \
     V *value = vref->data + vref->offset; \
    const uint64_t isize = iref->sizes[0]; \
     const Element<V> *elem = \
-        static_cast<SparseTensorCOO<V> *>(coo)->getNext(); \
+        static_cast<SparseTensorIterator<V> *>(iter)->getNext(); \
     if (elem == nullptr) \
       return false; \
     for (uint64_t r = 0; r < isize; r++) \
@@ -490,6 +528,13 @@
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_DELCOO)
 #undef IMPL_DELCOO

+#define IMPL_DELITER(VNAME, V) \
+  void delSparseTensorIterator##VNAME(void *iter) { \
+    delete static_cast<SparseTensorIterator<V> *>(iter); \
+  }
+MLIR_SPARSETENSOR_FOREVERY_V(IMPL_DELITER)
+#undef IMPL_DELITER
+
 char *getTensorFilename(index_type id) {
   char var[80];
   sprintf(var, "TENSOR%" PRIu64, id);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
@@ -58,7 +58,7 @@
 // CHECK: memref.store %[[TMP_15]], %[[TMP_0]][%[[TMP_13]], %[[TMP_14]]] : memref<5x4xf64>
 // CHECK: scf.yield
 // CHECK: }
-// CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: call @delSparseTensorIteratorF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
 // CHECK: %[[TMP_11:.*]] = bufferization.to_tensor %[[TMP_0]] : memref<5x4xf64>
 // CHECK: return %[[TMP_11]] : tensor<5x4xf64>
 // CHECK: }
@@ -141,7 +141,7 @@
 // CHECK: %[[TMP_25:.*]] = func.call @addEltF64(%[[TMP_7]], %[[TMP_20]], %[[TMP_10]], %[[TMP_5]]) : (!llvm.ptr<i8>, memref<f64>, memref<?xindex>, memref<?xindex>) -> !llvm.ptr<i8>
 // CHECK: scf.yield
 // CHECK: }
-// CHECK: call @delSparseTensorCOOF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: call @delSparseTensorIteratorF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> ()
 // CHECK: %[[TMP_21:.*]] = call @newSparseTensor(%[[TMP_1]], %[[TMP_3]], %[[TMP_5]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c2_i32]], %[[TMP_7]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
 // CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
 // CHECK: return %[[TMP_21]] : !llvm.ptr<i8>
@@ -225,7 +225,7 @@
 // CHECK: %[[TMP_25:.*]] = func.call @addEltF64(%[[TMP_7]], %[[TMP_20]], %[[TMP_10]], %[[TMP_5]]) : (!llvm.ptr<i8>, memref<f64>, memref<?xindex>, memref<?xindex>) -> !llvm.ptr<i8>
 // CHECK: scf.yield
 // CHECK: }
-// CHECK: call @delSparseTensorCOOF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: call @delSparseTensorIteratorF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> ()
 // CHECK: %[[TMP_21:.*]] = call @newSparseTensor(%[[TMP_1]], %[[TMP_3]], %[[TMP_5]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c2_i32]], %[[TMP_7]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
 // CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
 // CHECK: return %[[TMP_21]] : !llvm.ptr<i8>
@@ -287,7 +287,7 @@
 // CHECK: memref.store %[[TMP_15]], %[[TMP_0]][%[[TMP_12]], %[[TMP_14]]] : memref<4x5xf64>
 // CHECK: scf.yield
 // CHECK: }
-// CHECK: call @delSparseTensorCOOF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: call @delSparseTensorIteratorF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
 // CHECK: %[[TMP_11:.*]] = bufferization.to_tensor %[[TMP_0]] : memref<4x5xf64>
 // CHECK: return %[[TMP_11]] : tensor<4x5xf64>
 // CHECK: }
@@ -348,7 +348,7 @@
 // CHECK: memref.store %[[TMP_16]], %[[TMP_0]][%[[TMP_13]], %[[TMP_15]]] : memref<3x5xf64>
 // CHECK: scf.yield
 // CHECK: }
-// CHECK: call @delSparseTensorCOOF64(%[[TMP_8]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: call @delSparseTensorIteratorF64(%[[TMP_8]]) : (!llvm.ptr<i8>) -> ()
 // CHECK: %[[TMP_12:.*]] = bufferization.to_tensor %[[TMP_1]] : memref<?x?xf64>
 // CHECK: return %[[TMP_12]] : tensor<?x?xf64>
 // CHECK: }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -35,7 +35,7 @@
 // CHECK-CONV: }
 // CHECK-CONV: %[[N:.*]] = call @newSparseTensor
 // CHECK-CONV: call @delSparseTensorCOOF64
-// CHECK-CONV: call @delSparseTensorCOOF64
+// CHECK-CONV: call @delSparseTensorIteratorF64
 // CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
 //
 // rewrite for codegen:
@@ -97,7 +97,7 @@
 // CHECK-CONV: }
 // CHECK-CONV: %[[N:.*]] = call @newSparseTensor
 // CHECK-CONV: call @delSparseTensorCOOF64
-// CHECK-CONV: call @delSparseTensorCOOF64
+// CHECK-CONV: call @delSparseTensorIteratorF64
 // CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
 //
 // rewrite for codegen:
@@ -172,7 +172,7 @@
 // CHECK-CONV: }
 // CHECK-CONV: %[[N:.*]] = call @newSparseTensor
 // CHECK-CONV: call @delSparseTensorCOOF64
-// CHECK-CONV: call @delSparseTensorCOOF64
+// CHECK-CONV: call @delSparseTensorIteratorF64
 // CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
 //
 // rewrite for codegen:
@@ -244,7 +244,7 @@
 // CHECK-CONV: }
 // CHECK-CONV: %[[N:.*]] = call @newSparseTensor
 // CHECK-CONV: call @delSparseTensorCOOF64
-// CHECK-CONV: call @delSparseTensorCOOF64
+// CHECK-CONV: call @delSparseTensorIteratorF64
 // CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
 //
 // rewrite for codegen:
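
The following is a minimal, self-contained C++ sketch of the ownership pattern this patch introduces: the container exposes plain const_iterators, and a separate owning wrapper keeps the getNext()-returns-nullptr protocol used by the generated runtime calls. The names `Element`, `Coo`, and `CooIterator` are simplified stand-ins for the real `Element<V>`, `SparseTensorCOO<V>`, and `SparseTensorIterator<V>` (which, among other differences, store indices in a shared pool rather than per element); this is not the MLIR code itself.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Simplified stand-in for Element<V>: each element owns its own indices here.
struct Element {
  std::vector<uint64_t> indices;
  double value;
};

// Container side: expose plain C++ const_iterators (as the patched
// SparseTensorCOO now does) instead of startIterator()/getNext().
class Coo {
public:
  using const_iterator = std::vector<Element>::const_iterator;

  void add(std::vector<uint64_t> ind, double val) {
    elements.push_back(Element{std::move(ind), val});
  }
  const_iterator begin() const { return elements.cbegin(); }
  const_iterator end() const { return elements.cend(); }

private:
  std::vector<Element> elements;
};

// Wrapper side: owns the COO object and keeps the getNext()-returns-nullptr
// protocol that the generated _mlir_ciface_getNext* calls rely on; deleting
// the wrapper (cf. delSparseTensorIterator*) also frees the underlying COO.
class CooIterator {
public:
  explicit CooIterator(const Coo *coo)
      : coo(coo), it(coo->begin()), end(coo->end()) {}
  ~CooIterator() { delete coo; }
  CooIterator(const CooIterator &) = delete;
  CooIterator &operator=(const CooIterator &) = delete;

  const Element *getNext() { return it < end ? &*it++ : nullptr; }

private:
  const Coo *const coo; // owning pointer
  Coo::const_iterator it;
  const Coo::const_iterator end;
};

int main() {
  auto *coo = new Coo();
  coo->add({0, 1}, 3.0);
  coo->add({2, 2}, 5.0);

  // Range-for works directly on the container via the new begin()/end().
  for (const Element &e : *coo)
    std::printf("(%llu, %llu) = %g\n", (unsigned long long)e.indices[0],
                (unsigned long long)e.indices[1], e.value);

  // Runtime-style loop: the wrapper takes ownership of `coo` and frees it
  // when the wrapper itself is destroyed.
  CooIterator iter(coo);
  while (const Element *e = iter.getNext())
    std::printf("value = %g\n", e->value);
  return 0;
}

The design point mirrored from the patch is that iteration no longer mutates the container (no iteratorLocked/iteratorPos state), and lifetime management moves into a dedicated wrapper that the compiler frees via an explicit delSparseTensorIterator call instead of implicitly on the last getNext.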