Add BlockRange, an abstraction over ranges of Block * whose owner is either
an array of block operands or an array of Block *, and use range types in the
Operation construction APIs:

  * Operation::create and the private Operation constructor take
    TypeRange / ValueRange / BlockRange instead of ArrayRef parameters.
  * OperationState::addSuccessors takes a BlockRange, replacing the separate
    ArrayRef and SuccessorRange overloads.
  * Calls to create() inside Operation.cpp drop the redundant `Operation::`
    qualification.

diff --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h
--- a/mlir/include/mlir/IR/BlockSupport.h
+++ b/mlir/include/mlir/IR/BlockSupport.h
@@ -75,6 +75,47 @@
   friend RangeBaseT;
 };
 
+//===----------------------------------------------------------------------===//
+// BlockRange
+//===----------------------------------------------------------------------===//
+
+/// This class provides an abstraction over the different types of ranges over
+/// Blocks. In many cases, this prevents the need to explicitly materialize a
+/// SmallVector/std::vector. This class should be used in places that are not
+/// suitable for a more derived type (e.g. ArrayRef) or a template range
+/// parameter.
+class BlockRange final
+    : public llvm::detail::indexed_accessor_range_base<
+          BlockRange, llvm::PointerUnion<BlockOperand *, Block *const *>,
+          Block *, Block *, Block *> {
+public:
+  using RangeBaseT::RangeBaseT;
+  BlockRange(ArrayRef<Block *> blocks = llvm::None);
+  BlockRange(SuccessorRange successors);
+  template <typename Arg,
+            typename = typename std::enable_if_t<
+                std::is_constructible<ArrayRef<Block *>, Arg>::value>>
+  BlockRange(Arg &&arg)
+      : BlockRange(ArrayRef<Block *>(std::forward<Arg>(arg))) {}
+  BlockRange(std::initializer_list<Block *> blocks)
+      : BlockRange(ArrayRef<Block *>(blocks)) {}
+
+private:
+  /// The owner of the range is either:
+  /// * A pointer to the first element of an array of block operands.
+  /// * A pointer to the first element of an array of Block *.
+  using OwnerT = llvm::PointerUnion<BlockOperand *, Block *const *>;
+
+  /// See `llvm::detail::indexed_accessor_range_base` for details.
+  static OwnerT offset_base(OwnerT object, ptrdiff_t index);
+
+  /// See `llvm::detail::indexed_accessor_range_base` for details.
+  static Block *dereference_iterator(OwnerT object, ptrdiff_t index);
+
+  /// Allow access to `offset_base` and `dereference_iterator`.
+  friend RangeBaseT;
+};
+
 //===----------------------------------------------------------------------===//
 // Operation Iterators
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -32,25 +32,25 @@
 public:
   /// Create a new Operation with the specific fields.
   static Operation *create(Location location, OperationName name,
-                           ArrayRef<Type> resultTypes, ArrayRef<Value> operands,
+                           TypeRange resultTypes, ValueRange operands,
                            ArrayRef<NamedAttribute> attributes,
-                           ArrayRef<Block *> successors, unsigned numRegions);
+                           BlockRange successors, unsigned numRegions);
 
   /// Overload of create that takes an existing MutableDictionaryAttr to avoid
   /// unnecessarily uniquing a list of attributes.
   static Operation *create(Location location, OperationName name,
-                           ArrayRef<Type> resultTypes, ArrayRef<Value> operands,
+                           TypeRange resultTypes, ValueRange operands,
                            MutableDictionaryAttr attributes,
-                           ArrayRef<Block *> successors, unsigned numRegions);
+                           BlockRange successors, unsigned numRegions);
 
   /// Create a new Operation from the fields stored in `state`.
   static Operation *create(const OperationState &state);
 
   /// Create a new Operation with the specific fields.
   static Operation *create(Location location, OperationName name,
-                           ArrayRef<Type> resultTypes, ArrayRef<Value> operands,
+                           TypeRange resultTypes, ValueRange operands,
                            MutableDictionaryAttr attributes,
-                           ArrayRef<Block *> successors = {},
+                           BlockRange successors = {},
                            RegionRange regions = {});
 
   /// The name of an operation is the key identifier for it.
@@ -633,7 +633,7 @@
   bool hasValidOrder() { return orderIndex != kInvalidOrderIdx; }
 
 private:
-  Operation(Location location, OperationName name, ArrayRef<Type> resultTypes,
+  Operation(Location location, OperationName name, TypeRange resultTypes,
             unsigned numSuccessors, unsigned numRegions,
             const MutableDictionaryAttr &attributes, bool hasOperandStorage);
 
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -29,6 +29,7 @@
 
 namespace mlir {
 class Block;
+class BlockRange;
 class Dialect;
 class Operation;
 struct OperationState;
@@ -42,7 +43,6 @@
 class Region;
 class ResultRange;
 class RewritePattern;
-class SuccessorRange;
 class Type;
 class Value;
 class ValueRange;
@@ -394,12 +394,8 @@
     attributes.append(newAttributes);
   }
 
-  /// Add an array of successors.
-  void addSuccessors(ArrayRef<Block *> newSuccessors) {
-    successors.append(newSuccessors.begin(), newSuccessors.end());
-  }
   void addSuccessors(Block *successor) { successors.push_back(successor); }
-  void addSuccessors(SuccessorRange newSuccessors);
+  void addSuccessors(BlockRange newSuccessors);
 
   /// Create a region that should be attached to the operation.  These regions
   /// can be filled in immediately without waiting for Operation to be
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -282,7 +282,7 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Successors
+// SuccessorRange
 //===----------------------------------------------------------------------===//
 
 SuccessorRange::SuccessorRange(Block *block) : SuccessorRange(nullptr, 0) {
@@ -295,3 +295,29 @@
   if ((count = term->getNumSuccessors()))
     base = term->getBlockOperands().data();
 }
+
+//===----------------------------------------------------------------------===//
+// BlockRange
+//===----------------------------------------------------------------------===//
+
+BlockRange::BlockRange(ArrayRef<Block *> blocks) : BlockRange(nullptr, 0) {
+  if ((count = blocks.size()))
+    base = blocks.data();
+}
+
+BlockRange::BlockRange(SuccessorRange successors)
+    : BlockRange(successors.begin().getBase(), successors.size()) {}
+
+/// See `llvm::detail::indexed_accessor_range_base` for details.
+BlockRange::OwnerT BlockRange::offset_base(OwnerT object, ptrdiff_t index) {
+  if (auto *operand = object.dyn_cast<BlockOperand *>())
+    return {operand + index};
+  return {object.dyn_cast<Block *const *>() + index};
+}
+
+/// See `llvm::detail::indexed_accessor_range_base` for details.
+Block *BlockRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
+  if (const auto *operand = object.dyn_cast<BlockOperand *>())
+    return operand[index].get();
+  return object.dyn_cast<Block *const *>()[index];
+}
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -71,29 +71,24 @@
 
 /// Create a new Operation with the specific fields.
 Operation *Operation::create(Location location, OperationName name,
-                             ArrayRef<Type> resultTypes,
-                             ArrayRef<Value> operands,
+                             TypeRange resultTypes, ValueRange operands,
                              ArrayRef<NamedAttribute> attributes,
-                             ArrayRef<Block *> successors,
-                             unsigned numRegions) {
+                             BlockRange successors, unsigned numRegions) {
   return create(location, name, resultTypes, operands,
                 MutableDictionaryAttr(attributes), successors, numRegions);
 }
 
 /// Create a new Operation from operation state.
 Operation *Operation::create(const OperationState &state) {
-  return Operation::create(state.location, state.name, state.types,
-                           state.operands, state.attributes, state.successors,
-                           state.regions);
+  return create(state.location, state.name, state.types, state.operands,
+                state.attributes, state.successors, state.regions);
 }
 
 /// Create a new Operation with the specific fields.
 Operation *Operation::create(Location location, OperationName name,
-                             ArrayRef<Type> resultTypes,
-                             ArrayRef<Value> operands,
+                             TypeRange resultTypes, ValueRange operands,
                              MutableDictionaryAttr attributes,
-                             ArrayRef<Block *> successors,
-                             RegionRange regions) {
+                             BlockRange successors, RegionRange regions) {
   unsigned numRegions = regions.size();
   Operation *op = create(location, name, resultTypes, operands, attributes,
                          successors, numRegions);
@@ -106,11 +101,9 @@
 /// Overload of create that takes an existing MutableDictionaryAttr to avoid
 /// unnecessarily uniquing a list of attributes.
 Operation *Operation::create(Location location, OperationName name,
-                             ArrayRef<Type> resultTypes,
-                             ArrayRef<Value> operands,
+                             TypeRange resultTypes, ValueRange operands,
                              MutableDictionaryAttr attributes,
-                             ArrayRef<Block *> successors,
-                             unsigned numRegions) {
+                             BlockRange successors, unsigned numRegions) {
   // We only need to allocate additional memory for a subset of results.
   unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size());
   unsigned numInlineResults = OpResult::getNumInline(resultTypes.size());
@@ -167,7 +160,7 @@
 }
 
 Operation::Operation(Location location, OperationName name,
-                     ArrayRef<Type> resultTypes, unsigned numSuccessors,
+                     TypeRange resultTypes, unsigned numSuccessors,
                      unsigned numRegions,
                      const MutableDictionaryAttr &attributes,
                      bool hasOperandStorage)
@@ -611,8 +604,8 @@
     successors.push_back(mapper.lookupOrDefault(successor));
 
   // Create the new operation.
-  auto *newOp = Operation::create(getLoc(), getName(), getResultTypes(),
-                                  operands, attrs, successors, getNumRegions());
+  auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs,
+                       successors, getNumRegions());
 
   // Remember the mapping of any results.
   for (unsigned i = 0, e = getNumResults(); i != e; ++i)
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -186,7 +186,7 @@
   operands.append(newOperands.begin(), newOperands.end());
 }
 
-void OperationState::addSuccessors(SuccessorRange newSuccessors) {
+void OperationState::addSuccessors(BlockRange newSuccessors) {
   successors.append(newSuccessors.begin(), newSuccessors.end());
 }
 