diff --git a/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h b/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h
--- a/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h
+++ b/mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h
@@ -9,6 +9,7 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include <iterator>
 
 namespace mlir {
 /// A 2D array where each row may have different length. Elements of each row
@@ -25,15 +26,96 @@
 
   /// Accesses `pos`-th row.
   ArrayRef<T> operator[](size_t pos) const { return at(pos); }
-  ArrayRef<T> at(size_t pos) const { return slices[pos]; }
+  ArrayRef<T> at(size_t pos) const {
+    if (slices[pos].first == static_cast<size_t>(-1))
+      return ArrayRef<T>();
+    return ArrayRef<T>(storage).slice(slices[pos].first, slices[pos].second);
+  }
   MutableArrayRef<T> operator[](size_t pos) { return at(pos); }
-  MutableArrayRef<T> at(size_t pos) { return slices[pos]; }
+  MutableArrayRef<T> at(size_t pos) {
+    if (slices[pos].first == static_cast<size_t>(-1))
+      return MutableArrayRef<T>();
+    return MutableArrayRef<T>(storage).slice(slices[pos].first,
+                                             slices[pos].second);
+  }
+
+  /// Iterator over the rows.
+  class iterator
+      : public llvm::iterator_facade_base<
+            iterator, std::forward_iterator_tag, MutableArrayRef<T>,
+            std::ptrdiff_t, MutableArrayRef<T> *, MutableArrayRef<T>> {
+  public:
+    /// Creates the start iterator.
+    explicit iterator(RaggedArray &ragged) : ragged(&ragged), pos(0) {}
+
+    /// Creates the end iterator.
+    iterator(RaggedArray &ragged, size_t pos) : ragged(&ragged), pos(pos) {}
+
+    /// Dereferences the current iterator. Assumes in-bounds.
+    MutableArrayRef<T> operator*() const { return (*ragged)[pos]; }
+
+    /// Increments the iterator.
+    iterator &operator++() {
+      if (pos < ragged->slices.size())
+        ++pos;
+      return *this;
+    }
+
+    /// Compares the two iterators. Iterators into different ragged arrays
+    /// compare not equal.
+    bool operator==(const iterator &other) const {
+      return ragged == other.ragged && pos == other.pos;
+    }
+
+  private:
+    // Held by pointer rather than by reference so that iterators remain
+    // copy-assignable, as required of forward iterators.
+    RaggedArray *ragged;
+    size_t pos;
+  };
+
+  /// Constant iterator over the rows.
+  class const_iterator
+      : public llvm::iterator_facade_base<
+            const_iterator, std::forward_iterator_tag, ArrayRef<T>,
+            std::ptrdiff_t, ArrayRef<T> *, ArrayRef<T>> {
+  public:
+    /// Creates the start iterator.
+    explicit const_iterator(const RaggedArray &ragged)
+        : ragged(&ragged), pos(0) {}
+
+    /// Creates the end iterator.
+    const_iterator(const RaggedArray &ragged, size_t pos)
+        : ragged(&ragged), pos(pos) {}
+
+    /// Dereferences the current iterator. Assumes in-bounds.
+    ArrayRef<T> operator*() const { return (*ragged)[pos]; }
+
+    /// Increments the iterator.
+    const_iterator &operator++() {
+      if (pos < ragged->slices.size())
+        ++pos;
+      return *this;
+    }
+
+    /// Compares the two iterators. Iterators into different ragged arrays
+    /// compare not equal.
+    bool operator==(const const_iterator &other) const {
+      return ragged == other.ragged && pos == other.pos;
+    }
+
+  private:
+    // Held by pointer rather than by reference so that iterators remain
+    // copy-assignable, as required of forward iterators.
+    const RaggedArray *ragged;
+    size_t pos;
+  };
 
   /// Iterator over rows.
-  auto begin() { return slices.begin(); }
-  auto begin() const { return slices.begin(); }
-  auto end() { return slices.end(); }
-  auto end() const { return slices.end(); }
+  const_iterator begin() const { return const_iterator(*this); }
+  const_iterator end() const { return const_iterator(*this, slices.size()); }
+  iterator begin() { return iterator(*this); }
+  iterator end() { return iterator(*this, slices.size()); }
 
   /// Reserve space to store `size` rows with `nestedSize` elements each.
   void reserve(size_t size, size_t nestedSize = 0) {
@@ -53,38 +135,45 @@
   /// succeeding rows.
   template <typename Range>
   void replace(size_t pos, Range &&elements) {
-    auto from = slices[pos].data();
-    if (from != nullptr) {
-      auto to = std::next(from, slices[pos].size());
-      auto newFrom = storage.erase(from, to);
-      // Update the array refs after the underlying storage was shifted.
-      for (size_t i = pos + 1, e = size(); i < e; ++i) {
-        slices[i] = MutableArrayRef<T>(newFrom, slices[i].size());
-        std::advance(newFrom, slices[i].size());
+    if (slices[pos].first != static_cast<size_t>(-1)) {
+      size_t offset = slices[pos].first;
+      size_t length = slices[pos].second;
+      auto from = std::next(storage.begin(), offset);
+      storage.erase(from, std::next(from, length));
+      // Rows are not necessarily stored in row order, since a replaced row is
+      // re-appended at the end of `storage`. Shift every row stored after the
+      // erased segment; empty rows keep their sentinel offset.
+      for (auto &slice : slices) {
+        if (slice.first != static_cast<size_t>(-1) &&
+            slice.first >= offset + length)
+          slice.first -= length;
       }
     }
     slices[pos] = appendToStorage(std::forward<Range>(elements));
   }
 
   /// Appends `num` empty rows to the array.
-  void appendEmptyRows(size_t num) { slices.resize(slices.size() + num); }
+  void appendEmptyRows(size_t num) {
+    slices.resize(slices.size() + num, std::pair<size_t, size_t>(-1, 0));
+  }
 
 private:
-  /// Appends the given elements to the storage and returns an ArrayRef pointing
-  /// to them in the storage.
+  /// Appends the given elements to the storage and returns the (offset,
+  /// length) pair identifying them in the storage.
   template <typename Range>
-  MutableArrayRef<T> appendToStorage(Range &&elements) {
+  std::pair<size_t, size_t> appendToStorage(Range &&elements) {
     size_t start = storage.size();
     llvm::append_range(storage, std::forward<Range>(elements));
-    return MutableArrayRef<T>(storage).drop_front(start);
+    return std::make_pair(start, storage.size() - start);
   }
 
-  /// Outer elements of the ragged array. Each entry is a reference to a
-  /// contiguous segment in the `storage` list that contains the actual
-  /// elements. This allows for elements to be stored contiguously without
-  /// nested vectors and for different segments to be set or replaced in any
-  /// order.
-  SmallVector<MutableArrayRef<T>> slices;
+  /// Outer elements of the ragged array. Each entry is an (offset, length)
+  /// pair identifying a contiguous segment in the `storage` list that
+  /// contains the actual elements. Rows created by `appendEmptyRows` use the
+  /// sentinel offset `static_cast<size_t>(-1)` until they are replaced. This
+  /// allows for elements to be stored contiguously without nested vectors and
+  /// for different segments to be set or replaced in any order.
+  SmallVector<std::pair<size_t, size_t>> slices;
 
   /// Dense storage for ragged array elements.
   SmallVector<T> storage;