diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -271,25 +271,42 @@ -// mapped_iterator - This is a simple iterator adapter that causes a function to -// be applied whenever operator* is invoked on the iterator. -template ()(*std::declval()))> -class mapped_iterator +/// A base class for mapped iterators that do not store the map function. +/// Dereferencing applies the callable returned by the derived class's +/// `getFunction() const` method to the wrapped element; DerivedT must be the +/// derived iterator type itself (CRTP). +template +class mapped_iterator_base : public iterator_adaptor_base< - mapped_iterator, ItTy, + DerivedT, ItTy, typename std::iterator_traits::iterator_category, std::remove_reference_t, typename std::iterator_traits::difference_type, std::remove_reference_t *, ReferenceTy> { public: - mapped_iterator(ItTy U, FuncTy F) - : mapped_iterator::iterator_adaptor_base(std::move(U)), F(std::move(F)) {} + using BaseT = mapped_iterator_base; + + mapped_iterator_base(ItTy U) + : mapped_iterator_base::iterator_adaptor_base(std::move(U)) {} ItTy getCurrent() { return this->I; } - const FuncTy &getFunction() const { return F; } + ReferenceTy operator*() const { + return static_cast(*this).getFunction()(*this->I); + } +}; - ReferenceTy operator*() const { return F(*this->I); } +// mapped_iterator - This is a simple iterator adapter that causes a function to +// be applied whenever operator* is invoked on the iterator. +template ()(*std::declval()))> +class mapped_iterator + : public mapped_iterator_base, + ItTy, ReferenceTy> { +public: + mapped_iterator(ItTy U, FuncTy F) + : mapped_iterator::BaseT(std::move(U)), F(std::move(F)) {} + + const FuncTy &getFunction() const { return F; } private: FuncTy F; diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -323,24 +323,48 @@ /// Iterator for walking over APFloat values. class FloatElementIterator final - : public llvm::mapped_iterator> { + : public llvm::mapped_iterator_base { + public: + /// Return the map function used by this iterator.
+ auto getFunction() const { + return [&](const APInt &intValue) { return APFloat(*smt, intValue); }; + } + + private: friend DenseElementsAttr; /// Initializes the float element iterator to the specified iterator. - FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it); + FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it) + : BaseT(it), smt(&smt) {} + + /// The float semantics to use when constructing the APFloat. + const llvm::fltSemantics *smt; }; /// Iterator for walking over complex APFloat values. class ComplexFloatElementIterator final - : public llvm::mapped_iterator< - ComplexIntElementIterator, - std::function(const std::complex &)>> { + : public llvm::mapped_iterator_base> { + public: + /// Return the map function used by this iterator. + auto getFunction() const { + return [&](const std::complex &intValue) -> std::complex { + return {APFloat(*smt, intValue.real()), APFloat(*smt, intValue.imag())}; + }; + } + + private: friend DenseElementsAttr; /// Initializes the float element iterator to the specified iterator. ComplexFloatElementIterator(const llvm::fltSemantics &smt, - ComplexIntElementIterator it); + ComplexIntElementIterator it) + : BaseT(it), smt(&smt) {} + + /// The float semantics to use when constructing the APFloat. + const llvm::fltSemantics *smt; }; //===--------------------------------------------------------------------===// @@ -478,24 +502,29 @@ typename std::enable_if::value && !std::is_same::value>::type; template - using DerivedAttributeElementIterator = - llvm::mapped_iterator; + struct DerivedAttributeElementIterator + : public llvm::mapped_iterator_base, + AttributeElementIterator, T> { + using DerivedAttributeElementIterator::BaseT::BaseT; + + /// Return the map function used by this iterator.
+ auto getFunction() const { + return [](Attribute attr) { return attr.cast(); }; + } + }; template > iterator_range_impl> getValues() const { - auto castFn = [](Attribute attr) { return attr.template cast(); }; - return {Attribute::getType(), - llvm::map_range(getValues(), - static_cast(castFn))}; + using DerivedIterT = DerivedAttributeElementIterator; + return {Attribute::getType(), DerivedIterT(value_begin()), + DerivedIterT(value_end())}; } template > DerivedAttributeElementIterator value_begin() const { - auto castFn = [](Attribute attr) { return attr.template cast(); }; - return {value_begin(), static_cast(castFn)}; + return {value_begin()}; } template > DerivedAttributeElementIterator value_end() const { - auto castFn = [](Attribute attr) { return attr.template cast(); }; - return {value_end(), static_cast(castFn)}; + return {value_end()}; } /// Return the held element values as a range of bool. The element type of diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -156,20 +156,6 @@ class Diagnostic { using NoteVector = std::vector>; - /// This class implements a wrapper iterator around NoteVector::iterator to - /// implicitly dereference the unique_ptr. - template - class NoteIteratorImpl - : public llvm::mapped_iterator { - static ResultTy &unwrap(NotePtrTy note) { return *note; } - - public: - NoteIteratorImpl(IteratorTy it) - : llvm::mapped_iterator(it, - &unwrap) {} - }; - public: Diagnostic(Location loc, DiagnosticSeverity severity) : loc(loc), severity(severity) {} @@ -265,15 +251,17 @@ /// diagnostic. Notes may not be attached to other notes. Diagnostic &attachNote(Optional noteLoc = llvm::None); - using note_iterator = NoteIteratorImpl; - using const_note_iterator = NoteIteratorImpl; + using note_iterator = llvm::pointee_iterator; + /// Yields `const Diagnostic &`, keeping a const diagnostic's notes immutable. + using const_note_iterator = + llvm::pointee_iterator<NoteVector::const_iterator, const Diagnostic>; /// Returns the notes held by this diagnostic.
iterator_range getNotes() { - return {notes.begin(), notes.end()}; + return llvm::make_pointee_range(notes); } iterator_range getNotes() const { return {notes.begin(), notes.end()}; } /// Allow a diagnostic to be converted to 'failure'. diff --git a/mlir/include/mlir/IR/DialectInterface.h b/mlir/include/mlir/IR/DialectInterface.h --- a/mlir/include/mlir/IR/DialectInterface.h +++ b/mlir/include/mlir/IR/DialectInterface.h @@ -111,20 +111,18 @@ /// An iterator class that iterates the held interface objects of the given /// derived interface type. template - class iterator : public llvm::mapped_iterator< - InterfaceVectorT::const_iterator, - const InterfaceT &(*)(const DialectInterface *)> { - static const InterfaceT &remapIt(const DialectInterface *interface) { - return *static_cast(interface); + struct iterator + : public llvm::mapped_iterator_base, + InterfaceVectorT::const_iterator, + const InterfaceT &> { + using iterator::BaseT::BaseT; + + /// Return the map function used by this iterator. + auto getFunction() const { + return [](const DialectInterface *interface) -> const InterfaceT & { + return *static_cast(interface); + }; } - - iterator(InterfaceVectorT::const_iterator it) - : llvm::mapped_iterator< - InterfaceVectorT::const_iterator, - const InterfaceT &(*)(const DialectInterface *)>(it, &remapIt) {} - - /// Allow access to the constructor. - friend DialectInterfaceCollectionBase; }; /// Iterator access to the held interfaces. diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h --- a/mlir/include/mlir/IR/TypeRange.h +++ b/mlir/include/mlir/IR/TypeRange.h @@ -124,16 +124,18 @@ /// This class implements iteration on the types of a given range of values.
template class ValueTypeIterator final - : public llvm::mapped_iterator { - static Type unwrap(Value value) { return value.getType(); } - + : public llvm::mapped_iterator_base, + ValueIteratorT, Type> { public: - /// Provide a const dereference method. - Type operator*() const { return unwrap(*this->I); } + using ValueTypeIterator::BaseT::BaseT; - /// Initializes the type iterator to the specified value iterator. - ValueTypeIterator(ValueIteratorT it) - : llvm::mapped_iterator(it, &unwrap) {} + /// Return the map function used by this iterator (maps a value to its type). + auto getFunction() const { + return [](Value value) { return value.getType(); }; + } + + /// NOTE(review): same as mapped_iterator_base::operator*; likely removable. + Type operator*() const { return getFunction()(*this->I); } }; /// This class implements iteration on the types of a given range of values. diff --git a/mlir/include/mlir/IR/TypeUtilities.h b/mlir/include/mlir/IR/TypeUtilities.h --- a/mlir/include/mlir/IR/TypeUtilities.h +++ b/mlir/include/mlir/IR/TypeUtilities.h @@ -66,36 +66,33 @@ /// Dimensions are compatible if all non-dynamic dims are equal. LogicalResult verifyCompatibleDims(ArrayRef dims); + //===----------------------------------------------------------------------===// // Utility Iterators //===----------------------------------------------------------------------===// // An iterator for the element types of an op's operands of shaped types. class OperandElementTypeIterator final - : public llvm::mapped_iterator { + : public llvm::mapped_iterator_base { public: - /// Initializes the result element type iterator to the specified operand - /// iterator. - explicit OperandElementTypeIterator(Operation::operand_iterator it); + using BaseT::BaseT; -private: - static Type unwrap(Value value); + /// Return the map function used by this iterator (a non-owning reference). + function_ref getFunction() const; }; using OperandElementTypeRange = iterator_range; // An iterator for the tensor element types of an op's results of shaped types.
class ResultElementTypeIterator final - : public llvm::mapped_iterator { + : public llvm::mapped_iterator_base { public: - /// Initializes the result element type iterator to the specified result - /// iterator. - explicit ResultElementTypeIterator(Operation::result_iterator it); + using BaseT::BaseT; -private: - static Type unwrap(Value value); + /// Return the map function used by this iterator (a non-owning reference). + function_ref getFunction() const; }; using ResultElementTypeRange = iterator_range; diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -281,15 +281,18 @@ /// a specific use iterator. template class ValueUserIterator final - : public llvm::mapped_iterator { - static Operation *unwrap(OperandType &value) { return value.getOwner(); } - + : public llvm::mapped_iterator_base< + ValueUserIterator, UseIteratorT, + Operation *> { public: - /// Initializes the user iterator to the specified use iterator. - ValueUserIterator(UseIteratorT it) - : llvm::mapped_iterator( - it, &unwrap) {} + using ValueUserIterator::BaseT::BaseT; + + /// Return the map function used by this iterator. + auto getFunction() const { + return [](OperandType &value) { return value.getOwner(); }; + } + + /// Provide access to the underlying operation.
Operation *operator->() { return **this; } }; diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -661,27 +661,6 @@ readBits(getData(), offset + storageWidth, bitWidth)}; } -//===----------------------------------------------------------------------===// -// FloatElementIterator - -DenseElementsAttr::FloatElementIterator::FloatElementIterator( - const llvm::fltSemantics &smt, IntElementIterator it) - : llvm::mapped_iterator>( - it, [&](const APInt &val) { return APFloat(smt, val); }) {} - -//===----------------------------------------------------------------------===// -// ComplexFloatElementIterator - -DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( - const llvm::fltSemantics &smt, ComplexIntElementIterator it) - : llvm::mapped_iterator< - ComplexIntElementIterator, - std::function(const std::complex &)>>( - it, [&](const std::complex &val) -> std::complex { - return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; - }) {} - //===----------------------------------------------------------------------===// // DenseElementsAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -151,20 +151,17 @@ return success(); } -OperandElementTypeIterator::OperandElementTypeIterator( - Operation::operand_iterator it) - : llvm::mapped_iterator( - it, &unwrap) {} - -Type OperandElementTypeIterator::unwrap(Value value) { - return value.getType().cast().getElementType(); +/// Return the element type of `value`'s type, which must be a ShapedType. +/// This is a named function rather than a lambda because function_ref does not +/// own its callable: returning a function_ref to a temporary lambda dangles. +static Type getShapedValueElementType(Value value) { + return value.getType().cast<ShapedType>().getElementType(); +} + +function_ref<Type(Value)> OperandElementTypeIterator::getFunction() const { + return getShapedValueElementType; } -ResultElementTypeIterator::ResultElementTypeIterator( - Operation::result_iterator it) - : llvm::mapped_iterator( - it, &unwrap) {} - -Type
ResultElementTypeIterator::unwrap(Value value) { - return value.getType().cast().getElementType(); +function_ref<Type(Value)> ResultElementTypeIterator::getFunction() const { + return getShapedValueElementType; }