diff --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h --- a/mlir/include/mlir/IR/BlockSupport.h +++ b/mlir/include/mlir/IR/BlockSupport.h @@ -55,21 +55,32 @@ /// This class implements the successor iterators for Block. class SuccessorRange final : public llvm::detail::indexed_accessor_range_base< - SuccessorRange, BlockOperand *, Block *, Block *, Block *> { + SuccessorRange, llvm::PointerUnion, + Block *, Block *, Block *> { public: using RangeBaseT::RangeBaseT; + SuccessorRange(ArrayRef blocks = llvm::None); SuccessorRange(Block *block); SuccessorRange(Operation *term); + template , Arg>::value>> + SuccessorRange(Arg &&arg) + : SuccessorRange(ArrayRef(std::forward(arg))) {} + SuccessorRange(std::initializer_list blocks) + : SuccessorRange(ArrayRef(blocks)) {} private: + /// The owner of the range is either: + /// * A pointer to the first element of an array of block operands. + /// * A pointer to the first element of an array of Block *. + using OwnerT = llvm::PointerUnion; + /// See `llvm::detail::indexed_accessor_range_base` for details. - static BlockOperand *offset_base(BlockOperand *object, ptrdiff_t index) { - return object + index; - } + static OwnerT offset_base(OwnerT object, ptrdiff_t index); + /// See `llvm::detail::indexed_accessor_range_base` for details. - static Block *dereference_iterator(BlockOperand *object, ptrdiff_t index) { - return object[index].get(); - } + static Block *dereference_iterator(OwnerT object, ptrdiff_t index); /// Allow access to `offset_base` and `dereference_iterator`. friend RangeBaseT; diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -32,25 +32,25 @@ public: /// Create a new Operation with the specific fields. static Operation *create(Location location, OperationName name, - ArrayRef resultTypes, ArrayRef operands, + TypeRange resultTypes, ValueRange operands, ArrayRef attributes, - ArrayRef successors, unsigned numRegions); + SuccessorRange successors, unsigned numRegions); /// Overload of create that takes an existing MutableDictionaryAttr to avoid /// unnecessarily uniquing a list of attributes. static Operation *create(Location location, OperationName name, - ArrayRef resultTypes, ArrayRef operands, + TypeRange resultTypes, ValueRange operands, MutableDictionaryAttr attributes, - ArrayRef successors, unsigned numRegions); + SuccessorRange successors, unsigned numRegions); /// Create a new Operation from the fields stored in `state`. static Operation *create(const OperationState &state); /// Create a new Operation with the specific fields. static Operation *create(Location location, OperationName name, - ArrayRef resultTypes, ArrayRef operands, + TypeRange resultTypes, ValueRange operands, MutableDictionaryAttr attributes, - ArrayRef successors = {}, + SuccessorRange successors = {}, RegionRange regions = {}); /// The name of an operation is the key identifier for it. @@ -633,7 +633,7 @@ bool hasValidOrder() { return orderIndex != kInvalidOrderIdx; } private: - Operation(Location location, OperationName name, ArrayRef resultTypes, + Operation(Location location, OperationName name, TypeRange resultTypes, unsigned numSuccessors, unsigned numRegions, const MutableDictionaryAttr &attributes, bool hasOperandStorage); diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -397,10 +397,6 @@ attributes.append(newAttributes); } - /// Add an array of successors. - void addSuccessors(ArrayRef newSuccessors) { - successors.append(newSuccessors.begin(), newSuccessors.end()); - } void addSuccessors(Block *successor) { successors.push_back(successor); } void addSuccessors(SuccessorRange newSuccessors); diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -282,16 +282,37 @@ } //===----------------------------------------------------------------------===// -// Successors +// SuccessorRange //===----------------------------------------------------------------------===// -SuccessorRange::SuccessorRange(Block *block) : SuccessorRange(nullptr, 0) { +SuccessorRange::SuccessorRange(ArrayRef blocks) + : SuccessorRange(nullptr, 0) { + if ((count = blocks.size())) + base = blocks.data(); +} + +SuccessorRange::SuccessorRange(Block *block) : SuccessorRange() { if (Operation *term = block->getTerminator()) if ((count = term->getNumSuccessors())) base = term->getBlockOperands().data(); } -SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange(nullptr, 0) { +SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange() { if ((count = term->getNumSuccessors())) base = term->getBlockOperands().data(); } + +/// See `llvm::detail::indexed_accessor_range_base` for details. +SuccessorRange::OwnerT SuccessorRange::offset_base(OwnerT object, + ptrdiff_t index) { + if (auto *operand = object.dyn_cast()) + return {operand + index}; + return {object.dyn_cast() + index}; +} + +/// See `llvm::detail::indexed_accessor_range_base` for details. +Block *SuccessorRange::dereference_iterator(OwnerT object, ptrdiff_t index) { + if (const auto *operand = object.dyn_cast()) + return operand[index].get(); + return object.dyn_cast()[index]; +} diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -66,29 +66,24 @@ /// Create a new Operation with the specific fields. Operation *Operation::create(Location location, OperationName name, - ArrayRef resultTypes, - ArrayRef operands, + TypeRange resultTypes, ValueRange operands, ArrayRef attributes, - ArrayRef successors, - unsigned numRegions) { + SuccessorRange successors, unsigned numRegions) { return create(location, name, resultTypes, operands, MutableDictionaryAttr(attributes), successors, numRegions); } /// Create a new Operation from operation state. Operation *Operation::create(const OperationState &state) { - return Operation::create(state.location, state.name, state.types, - state.operands, state.attributes, state.successors, - state.regions); + return create(state.location, state.name, state.types, state.operands, + state.attributes, state.successors, state.regions); } /// Create a new Operation with the specific fields. Operation *Operation::create(Location location, OperationName name, - ArrayRef resultTypes, - ArrayRef operands, + TypeRange resultTypes, ValueRange operands, MutableDictionaryAttr attributes, - ArrayRef successors, - RegionRange regions) { + SuccessorRange successors, RegionRange regions) { unsigned numRegions = regions.size(); Operation *op = create(location, name, resultTypes, operands, attributes, successors, numRegions); @@ -101,11 +96,9 @@ /// Overload of create that takes an existing MutableDictionaryAttr to avoid /// unnecessarily uniquing a list of attributes. Operation *Operation::create(Location location, OperationName name, - ArrayRef resultTypes, - ArrayRef operands, + TypeRange resultTypes, ValueRange operands, MutableDictionaryAttr attributes, - ArrayRef successors, - unsigned numRegions) { + SuccessorRange successors, unsigned numRegions) { // We only need to allocate additional memory for a subset of results. unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); @@ -162,7 +155,7 @@ } Operation::Operation(Location location, OperationName name, - ArrayRef resultTypes, unsigned numSuccessors, + TypeRange resultTypes, unsigned numSuccessors, unsigned numRegions, const MutableDictionaryAttr &attributes, bool hasOperandStorage) @@ -606,8 +599,8 @@ successors.push_back(mapper.lookupOrDefault(successor)); // Create the new operation. - auto *newOp = Operation::create(getLoc(), getName(), getResultTypes(), - operands, attrs, successors, getNumRegions()); + auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, + successors, getNumRegions()); // Remember the mapping of any results. for (unsigned i = 0, e = getNumResults(); i != e; ++i)