diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -188,10 +188,9 @@ return p << (value ? StringRef("true") : "false"); } -template -inline OpAsmPrinter & -operator<<(OpAsmPrinter &p, - const iterator_range> &types) { +template +inline OpAsmPrinter &operator<<(OpAsmPrinter &p, + const ValueTypeRange &types) { interleaveComma(types, p); return p; } diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -232,7 +232,7 @@ // Support operand type iteration. using operand_type_iterator = operand_range::type_iterator; - using operand_type_range = iterator_range; + using operand_type_range = operand_range::type_range; operand_type_iterator operand_type_begin() { return operand_begin(); } operand_type_iterator operand_type_end() { return operand_end(); } operand_type_range getOperandTypes() { return getOperands().getTypes(); } @@ -260,7 +260,7 @@ /// Support result type iteration. 
using result_type_iterator = result_range::type_iterator; - using result_type_range = ArrayRef; + using result_type_range = result_range::type_range; result_type_iterator result_type_begin() { return getResultTypes().begin(); } result_type_iterator result_type_end() { return getResultTypes().end(); } result_type_range getResultTypes(); diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -32,14 +32,17 @@ class OpAsmParser; class OpAsmParserResult; class OpAsmPrinter; +class OperandRange; class OpFoldResult; class ParseResult; class Pattern; class Region; +class ResultRange; class RewritePattern; class Type; class Value; class ValueRange; +template class ValueTypeRange; /// This is an adaptor from a list of values to named operands of OpTy. In a /// generic operation context, e.g., in dialect conversions, an ordered array of @@ -536,6 +539,46 @@ //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// +// TypeRange + +/// This class provides an abstraction over the different kinds of ranges of +/// value types. In many cases, this avoids the need to explicitly materialize +/// a SmallVector/std::vector. This class should be used in places that are not +/// suitable for a more derived type (e.g. ArrayRef) or a template range +/// parameter. It points at existing storage rather than copying the types. 
+class TypeRange + : public detail::indexed_accessor_range_base< + TypeRange, + llvm::PointerUnion, Type, + Type, Type> { +public: + using RangeBaseT::RangeBaseT; + TypeRange(ArrayRef types = llvm::None); + explicit TypeRange(OperandRange values); + explicit TypeRange(ResultRange values); + explicit TypeRange(ValueRange values); + template + TypeRange(ValueTypeRange values) + : TypeRange(ValueRangeT(values.begin().getCurrent(), + values.end().getCurrent())) {} + +private: + /// The owner of the range is either: + /// * A pointer to the first element of an array of values. + /// * A pointer to the first element of an array of types. + /// * A pointer to the first element of an array of operands. + using OwnerT = llvm::PointerUnion; + + /// Offsets `object` by `index`; see `detail::indexed_accessor_range_base`. + static OwnerT offset_base(OwnerT object, ptrdiff_t index); + /// Returns the type at `index`; see `detail::indexed_accessor_range_base`. + static Type dereference_iterator(OwnerT object, ptrdiff_t index); + + /// Allow access to `offset_base` and `dereference_iterator`. + friend RangeBaseT; +}; + +//===----------------------------------------------------------------------===// // ValueTypeRange /// This class implements iteration on the types of a given range of values. @@ -555,6 +598,18 @@ : llvm::mapped_iterator(it, &unwrap) {} }; +/// This class implements a range over the types of a given range of values. +template +class ValueTypeRange final + : public llvm::iterator_range< + ValueTypeIterator> { +public: + using llvm::iterator_range< + ValueTypeIterator>::iterator_range; + template + ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {} +}; + //===----------------------------------------------------------------------===// // OperandRange @@ -568,7 +623,8 @@ /// Returns the types of the values within this range. 
using type_iterator = ValueTypeIterator; - iterator_range getTypes() const { return {begin(), end()}; } + using type_range = ValueTypeRange; + type_range getTypes() const { return {begin(), end()}; } private: /// See `detail::indexed_accessor_range_base` for details. @@ -598,7 +654,8 @@ /// Returns the types of the values within this range. using type_iterator = ArrayRef::iterator; - ArrayRef getTypes() const; + using type_range = ArrayRef; + type_range getTypes() const; private: /// See `indexed_accessor_range` for details. @@ -666,7 +723,8 @@ /// Returns the types of the values within this range. using type_iterator = ValueTypeIterator; - iterator_range getTypes() const { return {begin(), end()}; } + using type_range = ValueTypeRange; + type_range getTypes() const { return {begin(), end()}; } private: using OwnerT = detail::ValueRangeOwner; diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -141,6 +141,43 @@ //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// +// TypeRange + +TypeRange::TypeRange(ArrayRef types) + : TypeRange(types.data(), types.size()) {} +TypeRange::TypeRange(OperandRange values) + : TypeRange(values.begin().getBase(), values.size()) {} +TypeRange::TypeRange(ResultRange values) + : TypeRange(values.getBase()->getResultTypes().slice(values.getStartIndex(), + values.size())) {} +TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) { + detail::ValueRangeOwner owner = values.begin().getBase(); + if (auto *op = reinterpret_cast(owner.ptr.dyn_cast())) + this->base = &op->getResultTypes()[owner.startIndex]; + else if (auto *operand = owner.ptr.dyn_cast()) + this->base = operand; + else + this->base = owner.ptr.get(); +} + +/// Offsets `object` by `index`; see `detail::indexed_accessor_range_base`. 
+TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) { + if (auto *value = object.dyn_cast()) + return {value + index}; + if (auto *operand = object.dyn_cast()) + return {operand + index}; + return {object.dyn_cast() + index}; +} +/// Returns the type at `index`; see `detail::indexed_accessor_range_base`. +Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) { + if (auto *value = object.dyn_cast()) + return (value + index)->getType(); + if (auto *operand = object.dyn_cast()) + return (operand + index)->get().getType(); + return object.dyn_cast()[index]; +} + +//===----------------------------------------------------------------------===// // OperandRange OperandRange::OperandRange(Operation *op)