diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -108,7 +108,12 @@
   void eraseArguments(ArrayRef<unsigned> argIndices);
   /// Erases the arguments that have their corresponding bit set in
   /// `eraseIndices` and removes them from the argument list.
-  void eraseArguments(llvm::BitVector eraseIndices);
+  void eraseArguments(const llvm::BitVector &eraseIndices);
+  /// Erases arguments using the given predicate. If the predicate returns true,
+  /// that argument is erased. The predicate is invoked exactly once per
+  /// argument, in order, and before that argument is renumbered, so
+  /// `getArgNumber()` still reports the original position inside it.
+  void eraseArguments(function_ref<bool(BlockArgument)> shouldEraseFn);
 
   unsigned getNumArguments() { return arguments.size(); }
   BlockArgument getArgument(unsigned i) { return arguments[i]; }
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -188,23 +188,32 @@
   eraseArguments(eraseIndices);
 }
 
-void Block::eraseArguments(llvm::BitVector eraseIndices) {
-  // We do this in reverse so that we erase later indices before earlier
-  // indices, to avoid shifting the later indices.
-  unsigned originalNumArgs = getNumArguments();
-  int64_t firstErased = originalNumArgs;
-  for (unsigned i = 0; i < originalNumArgs; ++i) {
-    int64_t currentPos = originalNumArgs - i - 1;
-    if (eraseIndices.test(currentPos)) {
-      arguments[currentPos].destroy();
-      arguments.erase(arguments.begin() + currentPos);
-      firstErased = currentPos;
+void Block::eraseArguments(const llvm::BitVector &eraseIndices) {
+  eraseArguments(
+      [&](BlockArgument arg) { return eraseIndices.test(arg.getArgNumber()); });
+}
+
+void Block::eraseArguments(function_ref<bool(BlockArgument)> shouldEraseFn) {
+  auto firstDead = llvm::find_if(arguments, shouldEraseFn);
+  if (firstDead == arguments.end())
+    return;
+
+  // Destroy the first dead argument here; find_if already applied the
+  // predicate to it, so it must not be evaluated a second time.
+  unsigned index = firstDead->getArgNumber();
+  firstDead->destroy();
+
+  // Compact the remaining arguments in place, destroying the dead ones.
+  for (auto it = std::next(firstDead), e = arguments.end(); it != e; ++it) {
+    // Destroy dead arguments; move live ones down and renumber them.
+    if (shouldEraseFn(*it)) {
+      it->destroy();
+    } else {
+      it->setArgNumber(index++);
+      *firstDead++ = *it;
     }
   }
-  // Update the cached position for the arguments after the first erased one.
-  int64_t index = firstErased;
-  for (BlockArgument arg : llvm::drop_begin(arguments, index))
-    arg.setArgNumber(index++);
+  arguments.erase(firstDead, arguments.end());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -139,9 +139,23 @@
 class LiveMap {
 public:
   /// Value methods.
-  bool wasProvenLive(Value value) { return liveValues.count(value); }
+  bool wasProvenLive(Value value) {
+    // TODO: For results that are removable, e.g. for region based control flow,
+    // we could allow for these values to be tracked independently.
+    if (OpResult result = value.dyn_cast<OpResult>())
+      return wasProvenLive(result.getOwner());
+    return wasProvenLive(value.cast<BlockArgument>());
+  }
+  bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
   void setProvedLive(Value value) {
-    changed |= liveValues.insert(value).second;
+    // TODO: For results that are removable, e.g. for region based control flow,
+    // we could allow for these values to be tracked independently.
+    if (OpResult result = value.dyn_cast<OpResult>())
+      return setProvedLive(result.getOwner());
+    setProvedLive(value.cast<BlockArgument>());
+  }
+  void setProvedLive(BlockArgument arg) {
+    changed |= liveValues.insert(arg).second;
   }
 
   /// Operation methods.
@@ -192,15 +206,6 @@
     liveMap.setProvedLive(value);
 }
 
-static bool isOpIntrinsicallyLive(Operation *op) {
-  // This pass doesn't modify the CFG, so terminators are never deleted.
-  if (op->mightHaveTrait<OpTrait::IsTerminator>())
-    return true;
-  // If the op has a side effect, we treat it as live.
-  // TODO: Properly handle region side effects.
-  return !MemoryEffectOpInterface::hasNoEffect(op) || op->getNumRegions() != 0;
-}
-
 static void propagateLiveness(Region &region, LiveMap &liveMap);
 
 static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
@@ -226,9 +231,6 @@
 }
 
 static void propagateLiveness(Operation *op, LiveMap &liveMap) {
-  // All Value's are either a block argument or an op result.
-  // We call processValue on those cases.
-
   // Recurse on any regions the op has.
   for (Region &region : op->getRegions())
     propagateLiveness(region, liveMap);
@@ -237,18 +239,17 @@
   if (op->hasTrait<OpTrait::IsTerminator>())
     return propagateTerminatorLiveness(op, liveMap);
 
-  // Process the op itself.
-  if (isOpIntrinsicallyLive(op)) {
-    liveMap.setProvedLive(op);
+  // Don't reprocess live operations.
+  if (liveMap.wasProvenLive(op))
     return;
-  }
+
+  // Process the op itself.
+  if (!wouldOpBeTriviallyDead(op))
+    return liveMap.setProvedLive(op);
+
+  // If the op isn't intrinsically alive, check its results.
   for (Value value : op->getResults())
     processValue(value, liveMap);
-  bool provedLive = llvm::any_of(op->getResults(), [&](Value value) {
-    return liveMap.wasProvenLive(value);
-  });
-  if (provedLive)
-    liveMap.setProvedLive(op);
 }
 
 static void propagateLiveness(Region &region, LiveMap &liveMap) {
@@ -260,8 +261,18 @@
     // faster convergence to a fixed point (we try to visit uses before defs).
     for (Operation &op : llvm::reverse(block->getOperations()))
       propagateLiveness(&op, liveMap);
-    for (Value value : block->getArguments())
-      processValue(value, liveMap);
+
+    // We currently do not remove entry block arguments, so there is no need to
+    // track their liveness.
+    // TODO: We could track these and enable removing dead operands/arguments
+    // from region control flow operations.
+    if (block->isEntryBlock())
+      continue;
+
+    for (Value value : block->getArguments()) {
+      if (!liveMap.wasProvenLive(value))
+        processValue(value, liveMap);
+    }
   }
 }
 
@@ -314,11 +325,12 @@
       eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
       for (Operation &childOp :
            llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
-        erasedAnything |=
-            succeeded(deleteDeadness(childOp.getRegions(), liveMap));
         if (!liveMap.wasProvenLive(&childOp)) {
           erasedAnything = true;
           childOp.erase();
+        } else {
+          erasedAnything |=
+              succeeded(deleteDeadness(childOp.getRegions(), liveMap));
         }
       }
     }
@@ -326,13 +338,12 @@
     // The entry block has an unknown contract with their enclosing block, so
     // skip it.
     for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
-      // Iterate in reverse to avoid shifting later arguments when deleting
-      // earlier arguments.
-      for (unsigned i = 0, e = block.getNumArguments(); i < e; i++)
-        if (!liveMap.wasProvenLive(block.getArgument(e - i - 1))) {
-          block.eraseArgument(e - i - 1);
-          erasedAnything = true;
-        }
+      // Erasing dead arguments modifies the IR even when no operation was
+      // erased, so it must be reported through `erasedAnything` as well.
+      unsigned numArgsBefore = block.getNumArguments();
+      block.eraseArguments(
+          [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
+      erasedAnything |= block.getNumArguments() != numArgsBefore;
     }
   }
   return success(erasedAnything);