diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -249,7 +249,8 @@
 namespace detail {
 /// The internal implementation of a BlockArgument.
 class BlockArgumentImpl : public IRObjectWithUseList {
-  BlockArgumentImpl(Type type, Block *owner) : type(type), owner(owner) {}
+  BlockArgumentImpl(Type type, Block *owner, int64_t index)
+      : type(type), owner(owner), index(index) {}
 
   /// The type of this argument.
   Type type;
@@ -257,6 +258,9 @@
   /// The owner of this argument.
   Block *owner;
 
+  /// The cached position of this argument in the owner's argument list.
+  int64_t index;
+
   /// Allow access to owner and constructor.
   friend BlockArgument;
 };
@@ -281,12 +285,12 @@
   void setType(Type newType) { getImpl()->type = newType; }
 
   /// Returns the number of this argument.
-  unsigned getArgNumber() const;
+  unsigned getArgNumber() const { return getImpl()->index; }
 
 private:
   /// Allocate a new argument with the given type and owner.
-  static BlockArgument create(Type type, Block *owner) {
-    return new detail::BlockArgumentImpl(type, owner);
+  static BlockArgument create(Type type, Block *owner, int64_t index) {
+    return new detail::BlockArgumentImpl(type, owner, index);
   }
 
   /// Destroy and deallocate this argument.
@@ -298,7 +302,10 @@
         ownerAndKind.getPointer());
   }
 
-  /// Allow access to `create` and `destroy`.
+  /// Overwrite the cached position of this argument in the argument list.
+  void setArgNumber(int64_t index) { getImpl()->index = index; }
+
+  /// Allow access to `create`, `destroy` and `setArgNumber`.
   friend Block;
 
   /// Allow access to 'getImpl'.
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -12,17 +12,6 @@
 #include "llvm/ADT/BitVector.h"
 using namespace mlir;
 
-//===----------------------------------------------------------------------===//
-// BlockArgument
-//===----------------------------------------------------------------------===//
-
-/// Returns the number of this argument.
-unsigned BlockArgument::getArgNumber() const {
-  // Arguments are not stored in place, so we have to find it within the list.
-  auto argList = getOwner()->getArguments();
-  return std::distance(argList.begin(), llvm::find(argList, *this));
-}
-
 //===----------------------------------------------------------------------===//
 // Block
 //===----------------------------------------------------------------------===//
@@ -150,7 +139,7 @@
 }
 
 BlockArgument Block::addArgument(Type type) {
-  BlockArgument arg = BlockArgument::create(type, this);
+  BlockArgument arg = BlockArgument::create(type, this, arguments.size());
   arguments.push_back(arg);
   return arg;
 }
@@ -165,16 +154,34 @@
 }
 
 BlockArgument Block::insertArgument(unsigned index, Type type) {
-  auto arg = BlockArgument::create(type, this);
+  auto arg = BlockArgument::create(type, this, index);
   assert(index <= arguments.size());
   arguments.insert(arguments.begin() + index, arg);
+  // Update the cached position for all the arguments after the newly inserted
+  // one.
+  ++index;
+  for (BlockArgument shifted : llvm::drop_begin(arguments, index))
+    shifted.setArgNumber(index++);
   return arg;
 }
 
+/// Insert one value to the given position of the argument list. The existing
+/// arguments are shifted. The block is expected not to have predecessors.
+BlockArgument Block::insertArgument(args_iterator it, Type type) {
+  assert(llvm::empty(getPredecessors()) &&
+         "cannot insert arguments to blocks with predecessors");
+  // Derive the index from the iterator distance rather than dereferencing
+  // `it`, which may be args_end() when appending.
+  return insertArgument(std::distance(args_begin(), it), type);
+}
+
 void Block::eraseArgument(unsigned index) {
   assert(index < arguments.size());
   arguments[index].destroy();
   arguments.erase(arguments.begin() + index);
+  // Update the cached position for the arguments that followed the erased one.
+  for (BlockArgument arg : llvm::drop_begin(arguments, index))
+    arg.setArgNumber(index++);
 }
 
 void Block::eraseArguments(ArrayRef<unsigned> argIndices) {
@@ -188,23 +195,20 @@
   // We do this in reverse so that we erase later indices before earlier
   // indices, to avoid shifting the later indices.
   unsigned originalNumArgs = getNumArguments();
-  for (unsigned i = 0; i < originalNumArgs; ++i)
-    if (eraseIndices.test(originalNumArgs - i - 1))
-      eraseArgument(originalNumArgs - i - 1);
-}
-
-/// Insert one value to the given position of the argument list. The existing
-/// arguments are shifted. The block is expected not to have predecessors.
-BlockArgument Block::insertArgument(args_iterator it, Type type) {
-  assert(llvm::empty(getPredecessors()) &&
-         "cannot insert arguments to blocks with predecessors");
-
-  // Use the args_iterator (on the BlockArgListType) to compute the insertion
-  // iterator in the underlying argument storage.
-  size_t distance = std::distance(args_begin(), it);
-  auto arg = BlockArgument::create(type, this);
-  arguments.insert(std::next(arguments.begin(), distance), arg);
-  return arg;
+  for (unsigned i = 0; i < originalNumArgs; ++i) {
+    int64_t currentPos = originalNumArgs - i - 1;
+    if (eraseIndices.test(currentPos)) {
+      arguments[currentPos].destroy();
+      arguments.erase(arguments.begin() + currentPos);
+    }
+  }
+  // Update the cached position for the arguments after the first erased one.
+  // `find_first()` returns -1 when no bit is set, i.e. nothing was erased.
+  int64_t index = eraseIndices.find_first();
+  if (index < 0)
+    return;
+  for (BlockArgument arg : llvm::drop_begin(arguments, index))
+    arg.setArgNumber(index++);
 }
 
 //===----------------------------------------------------------------------===//