diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -307,6 +307,32 @@ return make_range(map_iterator(C.begin(), F), map_iterator(C.end(), F)); } +/// A base type of mapped iterator, that is useful for building derived +/// iterators that do not need/want to store the map function (as in +/// mapped_iterator). These iterators must simply provide a `mapElement` method +/// that defines how to map a value of the iterator to the provided reference +/// type. +template +class mapped_iterator_base + : public iterator_adaptor_base< + DerivedT, ItTy, + typename std::iterator_traits::iterator_category, + std::remove_reference_t, + typename std::iterator_traits::difference_type, + std::remove_reference_t *, ReferenceTy> { +public: + using BaseT = mapped_iterator_base; + + mapped_iterator_base(ItTy U) + : mapped_iterator_base::iterator_adaptor_base(std::move(U)) {} + + ItTy getCurrent() { return this->I; } + + ReferenceTy operator*() const { + return static_cast(*this).mapElement(*this->I); + } +}; + /// Helper to determine if type T has a member called rbegin(). template class has_rbegin_impl { using yes = char[1]; diff --git a/llvm/unittests/ADT/MappedIteratorTest.cpp b/llvm/unittests/ADT/MappedIteratorTest.cpp --- a/llvm/unittests/ADT/MappedIteratorTest.cpp +++ b/llvm/unittests/ADT/MappedIteratorTest.cpp @@ -47,4 +47,67 @@ EXPECT_EQ(M[1], 42) << "assignment should have modified M"; } +TEST(MappedIteratorTest, CustomIteratorApplyFunctionOnDereference) { + struct CustomMapIterator + : public llvm::mapped_iterator_base::iterator, int> { + using BaseT::BaseT; + + /// Map the element to the iterator result type. + int mapElement(int X) const { return X + 1; } + }; + + std::vector V({0}); + + CustomMapIterator I(V.begin()); + + EXPECT_EQ(*I, 1) << "should have applied function in dereference"; +} + +TEST(MappedIteratorTest, CustomIteratorApplyFunctionOnArrow) { + struct S { + int Z = 0; + }; + struct CustomMapIterator + : public llvm::mapped_iterator_base::iterator, S &> { + CustomMapIterator(std::vector::iterator it, S *P) : BaseT(it), P(P) {} + + /// Map the element to the iterator result type. + S &mapElement(int X) const { return *(P + X); } + + S *P; + }; + + std::vector V({0}); + S Y; + + CustomMapIterator I(V.begin(), &Y); + + I->Z = 42; + + EXPECT_EQ(Y.Z, 42) << "should have applied function during arrow"; +} + +TEST(MappedIteratorTest, CustomIteratorFunctionPreservesReferences) { + struct CustomMapIterator + : public llvm::mapped_iterator_base::iterator, int &> { + CustomMapIterator(std::vector::iterator it, std::map &M) + : BaseT(it), M(M) {} + + /// Map the element to the iterator result type. + int &mapElement(int X) const { return M[X]; } + + std::map &M; + }; + std::vector V({1}); + std::map M({{1, 1}}); + + auto I = CustomMapIterator(V.begin(), M); + *I = 42; + + EXPECT_EQ(M[1], 42) << "assignment should have modified M"; +} + } // anonymous namespace diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -323,24 +323,46 @@ /// Iterator for walking over APFloat values. class FloatElementIterator final - : public llvm::mapped_iterator> { + : public llvm::mapped_iterator_base { + public: + /// Map the element to the iterator result type. + APFloat mapElement(const APInt &value) const { + return APFloat(*smt, value); + } + + private: friend DenseElementsAttr; /// Initializes the float element iterator to the specified iterator. - FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it); + FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it) + : BaseT(it), smt(&smt) {} + + /// The float semantics to use when constructing the APFloat. + const llvm::fltSemantics *smt; }; /// Iterator for walking over complex APFloat values. class ComplexFloatElementIterator final - : public llvm::mapped_iterator< - ComplexIntElementIterator, - std::function(const std::complex &)>> { + : public llvm::mapped_iterator_base> { + public: + /// Map the element to the iterator result type. + std::complex mapElement(const std::complex &value) const { + return {APFloat(*smt, value.real()), APFloat(*smt, value.imag())}; + } + + private: friend DenseElementsAttr; /// Initializes the float element iterator to the specified iterator. ComplexFloatElementIterator(const llvm::fltSemantics &smt, - ComplexIntElementIterator it); + ComplexIntElementIterator it) + : BaseT(it), smt(&smt) {} + + /// The float semantics to use when constructing the APFloat. + const llvm::fltSemantics *smt; }; //===--------------------------------------------------------------------===// @@ -478,24 +500,27 @@ typename std::enable_if::value && !std::is_same::value>::type; template - using DerivedAttributeElementIterator = - llvm::mapped_iterator; + struct DerivedAttributeElementIterator + : public llvm::mapped_iterator_base, + AttributeElementIterator, T> { + using DerivedAttributeElementIterator::BaseT::BaseT; + + /// Map the element to the iterator result type. + T mapElement(Attribute attr) const { return attr.cast(); } + }; template > iterator_range_impl> getValues() const { - auto castFn = [](Attribute attr) { return attr.template cast(); }; - return {Attribute::getType(), - llvm::map_range(getValues(), - static_cast(castFn))}; + using DerivedIterT = DerivedAttributeElementIterator; + return {Attribute::getType(), DerivedIterT(value_begin()), + DerivedIterT(value_end())}; } template > DerivedAttributeElementIterator value_begin() const { - auto castFn = [](Attribute attr) { return attr.template cast(); }; - return {value_begin(), static_cast(castFn)}; + return {value_begin()}; } template > DerivedAttributeElementIterator value_end() const { - auto castFn = [](Attribute attr) { return attr.template cast(); }; - return {value_end(), static_cast(castFn)}; + return {value_end()}; } /// Return the held element values as a range of bool. The element type of diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -155,20 +155,6 @@ class Diagnostic { using NoteVector = std::vector>; - /// This class implements a wrapper iterator around NoteVector::iterator to - /// implicitly dereference the unique_ptr. - template - class NoteIteratorImpl - : public llvm::mapped_iterator { - static ResultTy &unwrap(NotePtrTy note) { return *note; } - - public: - NoteIteratorImpl(IteratorTy it) - : llvm::mapped_iterator(it, - &unwrap) {} - }; - public: Diagnostic(Location loc, DiagnosticSeverity severity) : loc(loc), severity(severity) {} @@ -262,15 +248,16 @@ /// diagnostic. Notes may not be attached to other notes. Diagnostic &attachNote(Optional noteLoc = llvm::None); - using note_iterator = NoteIteratorImpl; - using const_note_iterator = NoteIteratorImpl; + using note_iterator = llvm::pointee_iterator; + using const_note_iterator = + llvm::pointee_iterator; /// Returns the notes held by this diagnostic. iterator_range getNotes() { - return {notes.begin(), notes.end()}; + return llvm::make_pointee_range(notes); } iterator_range getNotes() const { - return {notes.begin(), notes.end()}; + return llvm::make_pointee_range(notes); } /// Allow a diagnostic to be converted to 'failure'. diff --git a/mlir/include/mlir/IR/DialectInterface.h b/mlir/include/mlir/IR/DialectInterface.h --- a/mlir/include/mlir/IR/DialectInterface.h +++ b/mlir/include/mlir/IR/DialectInterface.h @@ -111,20 +111,16 @@ /// An iterator class that iterates the held interface objects of the given /// derived interface type. template - class iterator : public llvm::mapped_iterator< - InterfaceVectorT::const_iterator, - const InterfaceT &(*)(const DialectInterface *)> { - static const InterfaceT &remapIt(const DialectInterface *interface) { + struct iterator + : public llvm::mapped_iterator_base, + InterfaceVectorT::const_iterator, + const InterfaceT &> { + using iterator::BaseT::BaseT; + + /// Map the element to the iterator result type. + const InterfaceT &mapElement(const DialectInterface *interface) const { return *static_cast(interface); } - - iterator(InterfaceVectorT::const_iterator it) - : llvm::mapped_iterator< - InterfaceVectorT::const_iterator, - const InterfaceT &(*)(const DialectInterface *)>(it, &remapIt) {} - - /// Allow access to the constructor. - friend DialectInterfaceCollectionBase; }; /// Iterator access to the held interfaces. diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h --- a/mlir/include/mlir/IR/TypeRange.h +++ b/mlir/include/mlir/IR/TypeRange.h @@ -124,16 +124,13 @@ /// This class implements iteration on the types of a given range of values. template class ValueTypeIterator final - : public llvm::mapped_iterator { - static Type unwrap(Value value) { return value.getType(); } - + : public llvm::mapped_iterator_base, + ValueIteratorT, Type> { public: - /// Provide a const dereference method. - Type operator*() const { return unwrap(*this->I); } + using ValueTypeIterator::BaseT::BaseT; - /// Initializes the type iterator to the specified value iterator. - ValueTypeIterator(ValueIteratorT it) - : llvm::mapped_iterator(it, &unwrap) {} + /// Map the element to the iterator result type. + Type mapElement(Value value) const { return value.getType(); } }; /// This class implements iteration on the types of a given range of values. diff --git a/mlir/include/mlir/IR/TypeUtilities.h b/mlir/include/mlir/IR/TypeUtilities.h --- a/mlir/include/mlir/IR/TypeUtilities.h +++ b/mlir/include/mlir/IR/TypeUtilities.h @@ -66,36 +66,33 @@ /// Dimensions are compatible if all non-dynamic dims are equal. LogicalResult verifyCompatibleDims(ArrayRef dims); + //===----------------------------------------------------------------------===// // Utility Iterators //===----------------------------------------------------------------------===// // An iterator for the element types of an op's operands of shaped types. class OperandElementTypeIterator final - : public llvm::mapped_iterator { + : public llvm::mapped_iterator_base { public: - /// Initializes the result element type iterator to the specified operand - /// iterator. - explicit OperandElementTypeIterator(Operation::operand_iterator it); + using BaseT::BaseT; -private: - static Type unwrap(Value value); + /// Map the element to the iterator result type. + Type mapElement(Value value) const; }; using OperandElementTypeRange = iterator_range; // An iterator for the tensor element types of an op's results of shaped types. class ResultElementTypeIterator final - : public llvm::mapped_iterator { + : public llvm::mapped_iterator_base { public: - /// Initializes the result element type iterator to the specified result - /// iterator. - explicit ResultElementTypeIterator(Operation::result_iterator it); + using BaseT::BaseT; -private: - static Type unwrap(Value value); + /// Map the element to the iterator result type. + Type mapElement(Value value) const; }; using ResultElementTypeRange = iterator_range; diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -281,15 +281,16 @@ /// a specific use iterator. template class ValueUserIterator final - : public llvm::mapped_iterator { - static Operation *unwrap(OperandType &value) { return value.getOwner(); } - + : public llvm::mapped_iterator_base< + ValueUserIterator, UseIteratorT, + Operation *> { public: - /// Initializes the user iterator to the specified use iterator. - ValueUserIterator(UseIteratorT it) - : llvm::mapped_iterator( - it, &unwrap) {} + using ValueUserIterator::BaseT::BaseT; + + /// Map the element to the iterator result type. + Operation *mapElement(OperandType &value) const { return value.getOwner(); } + + /// Provide access to the underlying operation. Operation *operator->() { return **this; } }; diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -667,27 +667,6 @@ readBits(getData(), offset + storageWidth, bitWidth)}; } -//===----------------------------------------------------------------------===// -// FloatElementIterator - -DenseElementsAttr::FloatElementIterator::FloatElementIterator( - const llvm::fltSemantics &smt, IntElementIterator it) - : llvm::mapped_iterator>( - it, [&](const APInt &val) { return APFloat(smt, val); }) {} - -//===----------------------------------------------------------------------===// -// ComplexFloatElementIterator - -DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( - const llvm::fltSemantics &smt, ComplexIntElementIterator it) - : llvm::mapped_iterator< - ComplexIntElementIterator, - std::function(const std::complex &)>>( - it, [&](const std::complex &val) -> std::complex { - return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; - }) {} - //===----------------------------------------------------------------------===// // DenseElementsAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -151,20 +151,10 @@ return success(); } -OperandElementTypeIterator::OperandElementTypeIterator( - Operation::operand_iterator it) - : llvm::mapped_iterator( - it, &unwrap) {} - -Type OperandElementTypeIterator::unwrap(Value value) { +Type OperandElementTypeIterator::mapElement(Value value) const { return value.getType().cast().getElementType(); } -ResultElementTypeIterator::ResultElementTypeIterator( - Operation::result_iterator it) - : llvm::mapped_iterator( - it, &unwrap) {} - -Type ResultElementTypeIterator::unwrap(Value value) { +Type ResultElementTypeIterator::mapElement(Value value) const { return value.getType().cast().getElementType(); }