# NOTE(review): This patch text appears to have been mangled in transit: line
#   breaks are collapsed, and text between '<' and '>' seems to have been
#   stripped (e.g. `iterator_range getTypes()`, `template class
#   indexed_accessor_range`, `ArrayRef referenceData`). That stripping appears
#   to have also removed the tail of the OperationSupport.h hunk and the header
#   of the following file diff. It will not apply as-is; regenerate it from the
#   original commit rather than hand-repairing it.
# NOTE(review): OperationSupport.h: the context comment
#   "/// Allow access to `dereference_iterator`." still names the old method,
#   but this patch renames ResultRange::dereference_iterator to `dereference`.
#   Update that comment in the same change.
# NOTE(review): Context doc comment "- Derefence an iterator pointing to a
#   parent base ..." has a typo ("Dereference").
# NOTE(review): IndexedAccessorTest.cpp: prefer
#   `ASSERT_EQ(referenceData.size(), range.size());` over
#   `ASSERT_TRUE(referenceData.size() == range.size());` so a failure prints
#   both sizes. `using namespace mlir::detail;` appears unused by the test
#   (indexed_accessor_range is referenced outside `detail::`); confirm and drop.
# NOTE(review): Removing `private:` and the `friend` declaration makes
#   `offset_base` / `dereference_iterator` public members of
#   indexed_accessor_range. This is presumably required now that the base's
#   DerivedT is the user's range class rather than indexed_accessor_range
#   itself. Confirm that widening the public API is intended, or document these
#   as implementation hooks.
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -598,8 +598,8 @@ iterator_range getTypes() const { return {begin(), end()}; } private: - /// See `detail::indexed_accessor_range_base` for details. - static OpResult dereference_iterator(Operation *op, ptrdiff_t index); + /// See `indexed_accessor_range` for details. + static OpResult dereference(Operation *op, ptrdiff_t index); /// Allow access to `dereference_iterator`. friend indexed_accessor_range &range) : indexed_accessor_range_base(range.begin(), range.end()) {} + indexed_accessor_range_base(BaseT base, ptrdiff_t count) + : base(base), count(count) {} iterator begin() const { return iterator(base, 0); } iterator end() const { return iterator(base, count); } @@ -267,8 +269,6 @@ } protected: - indexed_accessor_range_base(BaseT base, ptrdiff_t count) - : base(base), count(count) {} indexed_accessor_range_base(const indexed_accessor_range_base &) = default; indexed_accessor_range_base(indexed_accessor_range_base &&) = default; indexed_accessor_range_base & @@ -286,18 +286,20 @@ /// bases that are offsetable should derive from indexed_accessor_range_base /// instead. Derived range classes are expected to implement the following /// static method: -/// * ReferenceT dereference_iterator(const BaseT &base, ptrdiff_t index) +/// * ReferenceT dereference(const BaseT &base, ptrdiff_t index) /// - Derefence an iterator pointing to a parent base at the given index. 
template class indexed_accessor_range : public detail::indexed_accessor_range_base< - indexed_accessor_range, - std::pair, T, PointerT, ReferenceT> { + DerivedT, std::pair, T, PointerT, ReferenceT> { public: + indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count) + : detail::indexed_accessor_range_base< + DerivedT, std::pair, T, PointerT, ReferenceT>( + std::make_pair(base, startIndex), count) {} using detail::indexed_accessor_range_base< - indexed_accessor_range, - std::pair, T, PointerT, + DerivedT, std::pair, T, PointerT, ReferenceT>::indexed_accessor_range_base; /// Returns the current base of the range. @@ -306,14 +308,6 @@ /// Returns the current start index of the range. ptrdiff_t getStartIndex() const { return this->base.second; } -protected: - indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count) - : detail::indexed_accessor_range_base< - indexed_accessor_range, - std::pair, T, PointerT, ReferenceT>( - std::make_pair(base, startIndex), count) {} - -private: /// See `detail::indexed_accessor_range_base` for details. static std::pair offset_base(const std::pair &base, ptrdiff_t index) { @@ -325,13 +319,8 @@ static ReferenceT dereference_iterator(const std::pair &base, ptrdiff_t index) { - return DerivedT::dereference_iterator(base.first, base.second + index); + return DerivedT::dereference(base.first, base.second + index); } - - /// Allow access to `offset_base` and `dereference_iterator`. - friend detail::indexed_accessor_range_base< - indexed_accessor_range, - std::pair, T, PointerT, ReferenceT>; }; /// Given a container of pairs, return a range over the second elements. diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -152,8 +152,8 @@ ResultRange::ResultRange(Operation *op) : ResultRange(op, /*startIndex=*/0, op->getNumResults()) {} -/// See `detail::indexed_accessor_range_base` for details. 
-OpResult ResultRange::dereference_iterator(Operation *op, ptrdiff_t index) { +/// See `indexed_accessor_range` for details. +OpResult ResultRange::dereference(Operation *op, ptrdiff_t index) { return op->getResult(index); } diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -10,4 +10,5 @@ add_subdirectory(IR) add_subdirectory(Pass) add_subdirectory(SDBM) +add_subdirectory(Support) add_subdirectory(TableGen) diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/unittests/Support/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_unittest(MLIRSupportTests + IndexedAccessorTest.cpp +) + +target_link_libraries(MLIRSupportTests + PRIVATE MLIRSupport) diff --git a/mlir/unittests/Support/IndexedAccessorTest.cpp b/mlir/unittests/Support/IndexedAccessorTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/Support/IndexedAccessorTest.cpp @@ -0,0 +1,49 @@ +//===- IndexedAccessorTest.cpp - Indexed Accessor Tests -------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/ArrayRef.h" +#include "gmock/gmock.h" + +using namespace mlir; +using namespace mlir::detail; + +namespace { +/// Simple indexed accessor range that wraps an array. +template +struct ArrayIndexedAccessorRange + : public indexed_accessor_range, T *, T> { + ArrayIndexedAccessorRange(T *data, ptrdiff_t start, ptrdiff_t numElements) + : indexed_accessor_range, T *, T>( + data, start, numElements) {} + using indexed_accessor_range, T *, + T>::indexed_accessor_range; + + /// See `indexed_accessor_range` for details. 
+ static T &dereference(T *data, ptrdiff_t index) { return data[index]; } +}; +} // end anonymous namespace + +template +static void compareData(ArrayIndexedAccessorRange range, + ArrayRef referenceData) { + ASSERT_TRUE(referenceData.size() == range.size()); + ASSERT_TRUE(std::equal(range.begin(), range.end(), referenceData.begin())); +} + +namespace { +TEST(AccessorRange, SliceTest) { + int rawData[] = {0, 1, 2, 3, 4}; + ArrayRef data = llvm::makeArrayRef(rawData); + + ArrayIndexedAccessorRange range(rawData, /*start=*/0, /*numElements=*/5); + compareData(range, data); + compareData(range.slice(2, 3), data.slice(2, 3)); + compareData(range.slice(0, 5), data.slice(0, 5)); +} +} // end anonymous namespace