diff --git a/mlir/lib/Transforms/CopyRemoval.cpp b/mlir/lib/Transforms/CopyRemoval.cpp
--- a/mlir/lib/Transforms/CopyRemoval.cpp
+++ b/mlir/lib/Transforms/CopyRemoval.cpp
@@ -19,16 +19,28 @@
 //===----------------------------------------------------------------------===//
 // CopyRemovalPass
 //===----------------------------------------------------------------------===//
+
 /// This pass removes the redundant Copy operations. Additionally, it
 /// removes the leftover definition and deallocation operations by erasing the
 /// copy operation.
 class CopyRemovalPass : public PassWrapper<CopyRemovalPass, OperationPass<>> {
+public:
+  void runOnOperation() override {
+    getOperation()->walk([&](CopyOpInterface copyOp) {
+      reuseCopySourceAsTarget(copyOp);
+      reuseCopyTargetAsSource(copyOp);
+    });
+    for (Operation *op : eraseList)
+      op->erase();
+  }
+
 private:
   /// List of operations that need to be removed.
   DenseSet<Operation *> eraseList;
 
   /// Returns the deallocation operation for `value` in `block` if it exists.
   Operation *getDeallocationInBlock(Value value, Block *block) {
+    assert(block && "Block cannot be null");
     auto valueUsers = value.getUsers();
     auto it = llvm::find_if(valueUsers, [&](Operation *op) {
       auto effects = dyn_cast<MemoryEffectOpInterface>(op);
@@ -40,12 +52,12 @@
   /// Returns true if an operation between start and end operations has memory
   /// effect.
   bool hasMemoryEffectOpBetween(Operation *start, Operation *end) {
+    assert(start && end && "Start and end operations cannot be null");
     assert(start->getBlock() == end->getBlock() &&
            "Start and end operations should be in the same block.");
     Operation *op = start->getNextNode();
     while (op->isBeforeInBlock(end)) {
-      auto effects = dyn_cast<MemoryEffectOpInterface>(op);
-      if (effects)
+      if (isa<MemoryEffectOpInterface>(op))
         return true;
       op = op->getNextNode();
     }
@@ -55,6 +67,7 @@
   /// Returns true if `val` value has at least a user between `start` and
   /// `end` operations.
   bool hasUsersBetween(Value val, Operation *start, Operation *end) {
+    assert(start && end && "Start and end operations cannot be null");
     Block *block = start->getBlock();
     assert(block == end->getBlock() &&
            "Start and end operations should be in the same block.");
@@ -65,10 +78,11 @@
   };
 
   bool areOpsInTheSameBlock(ArrayRef<Operation *> operations) {
-    llvm::SmallPtrSet<Block *, 4> blocks;
-    for (Operation *op : operations)
-      blocks.insert(op->getBlock());
-    return blocks.size() == 1;
+    assert(!operations.empty() &&
+           "The operations list should contain at least a single operation");
+    Block *block = operations.front()->getBlock();
+    return llvm::none_of(
+        operations, [&](Operation *op) { return block != op->getBlock(); });
   }
 
   /// Input:
@@ -97,7 +111,7 @@
   /// TODO: Alias analysis is not available at the moment. Currently, we check
   /// if there are any operations with memory effects between copy and
   /// deallocation operations.
-  void ReuseCopySourceAsTarget(CopyOpInterface copyOp) {
+  void reuseCopySourceAsTarget(CopyOpInterface copyOp) {
     if (eraseList.count(copyOp))
       return;
 
@@ -147,7 +161,7 @@
   /// TODO: Alias analysis is not available at the moment. Currently, we check
   /// if there are any operations with memory effects between copy and
   /// deallocation operations.
-  void ReuseCopyTargetAsSource(CopyOpInterface copyOp) {
+  void reuseCopyTargetAsSource(CopyOpInterface copyOp) {
     if (eraseList.count(copyOp))
       return;
 
@@ -169,16 +183,6 @@
     eraseList.insert(fromDefiningOp);
     eraseList.insert(fromFreeingOp);
   }
-
-public:
-  void runOnOperation() override {
-    getOperation()->walk([&](CopyOpInterface copyOp) {
-      ReuseCopySourceAsTarget(copyOp);
-      ReuseCopyTargetAsSource(copyOp);
-    });
-    for (Operation *op : eraseList)
-      op->erase();
-  }
 };
 
 } // end anonymous namespace
@@ -186,6 +190,7 @@
 //===----------------------------------------------------------------------===//
 // CopyRemovalPass construction
 //===----------------------------------------------------------------------===//
+
 std::unique_ptr<Pass> mlir::createCopyRemovalPass() {
   return std::make_unique<CopyRemovalPass>();
 }