diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -243,6 +243,14 @@
     getOperandStorage().eraseOperands(idx, length);
   }
 
+  /// Erases the operands whose corresponding bit is set in `eraseIndices`,
+  /// removing them from the operand list while preserving the relative order
+  /// of the remaining operands. `eraseIndices` must hold exactly one bit per
+  /// operand.
+  void eraseOperands(const llvm::BitVector &eraseIndices) {
+    getOperandStorage().eraseOperands(eraseIndices);
+  }
+
   // Support operand iteration.
   using operand_range = OperandRange;
   using operand_iterator = operand_range::iterator;
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -28,6 +28,10 @@
 #include "llvm/Support/TrailingObjects.h"
 #include <memory>
 
+namespace llvm {
+class BitVector;
+} // end namespace llvm
+
 namespace mlir {
 class Dialect;
 class DictionaryAttr;
@@ -495,6 +499,11 @@
   /// Erase the operands held by the storage within the given range.
   void eraseOperands(unsigned start, unsigned length);
 
+  /// Erase the operands held by the storage that have their corresponding bit
+  /// set in `eraseIndices`. `eraseIndices` must hold exactly one bit per
+  /// operand; the remaining operands keep their relative order.
+  void eraseOperands(const llvm::BitVector &eraseIndices);
+
   /// Get the operation operands held by the storage.
   MutableArrayRef<OpOperand> getOperands() {
     return getStorage().getOperands();
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -12,10 +12,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/Block.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/Operation.h"
+#include "llvm/ADT/BitVector.h"
+
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -300,6 +300,29 @@
     operands[storage.numOperands + i].~OpOperand();
 }
 
+void detail::OperandStorage::eraseOperands(
+    const llvm::BitVector &eraseIndices) {
+  TrailingOperandStorage &storage = getStorage();
+  MutableArrayRef<OpOperand> operands = storage.getOperands();
+  assert(eraseIndices.size() == operands.size() &&
+         "expected exactly one bit per operand");
+
+  // Bail out early if no operand is marked for erasure.
+  int firstErasedIndex = eraseIndices.find_first();
+  if (firstErasedIndex == -1)
+    return;
+
+  // Compact the kept operands towards the front, preserving their relative
+  // order; every slot past the new operand count then holds either an erased
+  // or a moved-from operand, which is destroyed below.
+  storage.numOperands = firstErasedIndex;
+  for (unsigned i = firstErasedIndex + 1, e = operands.size(); i < e; ++i)
+    if (!eraseIndices.test(i))
+      operands[storage.numOperands++] = std::move(operands[i]);
+  for (OpOperand &operand : operands.drop_front(storage.numOperands))
+    operand.~OpOperand();
+}
+
 /// Resize the storage to the given size. Returns the array containing the new
 /// operands.
 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -9,6 +9,7 @@
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "llvm/ADT/BitVector.h"
 #include "gtest/gtest.h"
 
 using namespace mlir;
@@ -150,6 +151,46 @@
   useOp->destroy();
 }
 
+TEST(OperandStorageTest, RangeErase) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  Type type = builder.getNoneType();
+  Operation *useOp = createOp(&context, /*operands=*/llvm::None, {type, type});
+  Value operand1 = useOp->getResult(0);
+  Value operand2 = useOp->getResult(1);
+
+  // Create an operation with operands to erase.
+  Operation *user =
+      createOp(&context, {operand2, operand1, operand2, operand1});
+  llvm::BitVector eraseIndices(user->getNumOperands());
+
+  // Check erasing no operands.
+  user->eraseOperands(eraseIndices);
+  EXPECT_EQ(user->getNumOperands(), 4u);
+
+  // Check erasing disjoint operands.
+  eraseIndices.set(0);
+  eraseIndices.set(3);
+  user->eraseOperands(eraseIndices);
+  EXPECT_EQ(user->getNumOperands(), 2u);
+  EXPECT_EQ(user->getOperand(0), operand1);
+  EXPECT_EQ(user->getOperand(1), operand2);
+
+  // Check erasing all of the remaining operands, which must also unlink them
+  // from the use-lists of the values they referenced.
+  eraseIndices.resize(user->getNumOperands());
+  eraseIndices.set();
+  user->eraseOperands(eraseIndices);
+  EXPECT_EQ(user->getNumOperands(), 0u);
+  EXPECT_TRUE(operand1.use_empty());
+  EXPECT_TRUE(operand2.use_empty());
+
+  // Destroy the operations.
+  user->destroy();
+  useOp->destroy();
+}
+
 TEST(OperationOrderTest, OrderIsAlwaysValid) {
   MLIRContext context;
   Builder builder(&context);