diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -74,6 +74,9 @@
     ArrayRef<std::string> disabledPatterns = llvm::None,
     ArrayRef<std::string> enabledPatterns = llvm::None);
 
+/// Create a pass that removes unnecessary Copy operations.
+std::unique_ptr<OperationPass<FuncOp>> createCopyRemovalPass();
+
 /// Creates a pass to perform common sub expression elimination.
 std::unique_ptr<Pass> createCSEPass();
 
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -323,6 +323,11 @@
 ];
 }
 
+def CopyRemoval : FunctionPass<"copy-removal"> {
+  let summary = "Remove redundant copy operations";
+  let constructor = "mlir::createCopyRemovalPass()";
+}
+
 def Inliner : Pass<"inline"> {
   let summary = "Inline function calls";
   let constructor = "mlir::createInlinerPass()";
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@
   BufferResultsToOutParams.cpp
   BufferUtils.cpp
   Canonicalizer.cpp
+  CopyRemoval.cpp
   CSE.cpp
   Inliner.cpp
   LocationSnapshot.cpp
diff --git a/mlir/lib/Transforms/CopyRemoval.cpp b/mlir/lib/Transforms/CopyRemoval.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/CopyRemoval.cpp
@@ -0,0 +1,295 @@
+//===- CopyRemoval.cpp - Removes redundant copies -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "PassDetail.h"
+#include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+#define DEBUG_TYPE "copy-removal"
+
+using namespace mlir;
+using namespace MemoryEffects;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// CopyRemovalPass
+//===----------------------------------------------------------------------===//
+
+/// This pass removes redundant copy operations. Additionally, it
+/// removes leftover definition and deallocation operations by erasing the
+/// copy operation.
+/// Input:
+/// func() {
+///   %source = alloc()
+///   write_to(%source)
+///   %destination = alloc()
+///   copy(%source, %destination)
+///   dealloc(%source)
+///   return %destination
+/// }
+///
+/// Output:
+/// func(){
+///   %source = alloc()
+///   write_to(%source)
+///   return %source
+/// }
+/// Constraints:
+/// 1) There should not exist any users of `destination` before the copy op.
+/// 2) There should not be any write operations on `source` and `destination`
+///    after copy op.
+struct CopyRemovalPass : public CopyRemovalBase<CopyRemovalPass> {
+public:
+  void runOnFunction() override;
+
+private:
+  /// Returns the allocation operation for `value` if it exists.
+  /// nullptr otherwise.
+  Operation *getAllocationOp(Value value) {
+    if (Operation *op = value.getDefiningOp()) {
+      if (auto effects = dyn_cast<MemoryEffectOpInterface>(op))
+        if (effects.hasEffect<Allocate>())
+          return op;
+    }
+    return nullptr;
+  }
+
+  /// Returns the deallocation operation for `value` if it exists.
+  /// nullptr otherwise.
+  Operation *getDeallocationOp(Value value) {
+    auto valueUsers = value.getUsers();
+    auto it = llvm::find_if(valueUsers, [&](Operation *op) {
+      auto effects = dyn_cast<MemoryEffectOpInterface>(op);
+      return effects && effects.hasEffect<Free>();
+    });
+    return (it == valueUsers.end() ? nullptr : *it);
+  }
+
+  /// Check whether the write effect on `val` can be caused by `op`.
+  static bool doesOpHaveWriteEffect(Value val, Operation *op) {
+    // Check whether the operation `op` has write effect on the memory.
+    if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) {
+      if (!llvm::is_contained(val.getUsers(), op))
+        return false;
+      SmallVector<MemoryEffects::EffectInstance, 2> effects;
+      memEffect.getEffects(effects);
+
+      // Iterate through all the effects produced by `op`, and check
+      // whether one of them is MemoryEffects::Write.
+      return llvm::any_of(effects, [](MemoryEffects::EffectInstance effect) {
+        return isa<MemoryEffects::Write>(effect.getEffect());
+      });
+    }
+
+    if (op->hasTrait<OpTrait::HasRecursiveSideEffects>()) {
+      // Recurse into the regions for this op and check whether the
+      // internal operations may have the side effect `EffectType` on
+      // `val`.
+      for (Region &region : op->getRegions()) {
+        auto walkResult = region.walk([&](Operation *op) {
+          if (doesOpHaveWriteEffect(val, op))
+            return WalkResult::interrupt();
+          return WalkResult::advance();
+        });
+        if (walkResult.wasInterrupted())
+          return true;
+      }
+      return false;
+    }
+
+    // Otherwise, conservatively assume generic operations have the effect
+    // on the operation.
+    return true;
+  }
+
+  /// Check whether `op` uses `val`.
+  static bool doesOpUseVal(Value val, Operation *op) {
+    if (!llvm::is_contained(val.getUsers(), op))
+      return false;
+    return true;
+  }
+
+  /// Check whether an op that lies on one of the paths between `start`
+  /// and `end` satisfies `checkPropertiesOfOperation`. If the start and end
+  /// operations are in different regions, recursively consider all paths from
+  /// `start` to the parent of `end` and all paths from the parent of `end` to
+  /// `end`. When `start` and `end` exist in the same region, perform a CFG
+  /// traversal to check all the relevant operations.
+  bool hasInterveningOp(
+      Value val, Operation *start, Operation *end,
+      std::function<bool(Value, Operation *)> checkPropertiesOfOperation) {
+
+    // Check for all paths from operation `fromOp` to operation `untilOp` for
+    // the given property.
+    std::function<bool(Operation *, Operation *)> recur =
+        [&](Operation *fromOp, Operation *untilOp) {
+          auto untilOpParentRegion = untilOp->getParentRegion();
+          auto untilOpParentOp = untilOp->getParentOp();
+          auto fromOpParentRegion = fromOp->getParentRegion();
+          auto fromOpBlock = fromOp->getBlock();
+          auto untilOpBlock = untilOp->getBlock();
+
+          if (!fromOpParentRegion->isAncestor(untilOpParentRegion))
+            return false;
+
+          // If the operations are in different regions, recursively consider
+          // all paths from `fromOp` to the parent of `untilOp` and all paths
+          // from the parent of `untilOp` to `untilOp`.
+          if (fromOpParentRegion != untilOpParentRegion) {
+            // Note: the recursion result must be propagated; previously it
+            // was discarded, ignoring intervening ops on that sub-path.
+            if (recur(fromOp, untilOpParentOp))
+              return true;
+            if (checkPropertiesOfOperation(val, untilOpParentOp))
+              return true;
+            return false;
+          }
+
+          // Now, assuming that `fromOp` and `untilOp` exist in the same
+          // region, perform a CFG traversal to check all the relevant
+          // operations.
+
+          // Additional blocks to consider.
+          SmallVector<Block *, 2> todoBlocks;
+          {
+            // First consider the parent block of `fromOp` and check all
+            // operations after `fromOp`.
+            for (auto iter = ++fromOp->getIterator(), end = fromOpBlock->end();
+                 iter != end && &*iter != untilOp; ++iter) {
+              if (checkPropertiesOfOperation(val, &*iter))
+                return true;
+            }
+
+            // If the parent of `fromOp` doesn't contain `untilOp`, add the
+            // successors to the list of blocks to check.
+            if (untilOpBlock != fromOpBlock)
+              for (Block *succ : fromOpBlock->getSuccessors())
+                todoBlocks.push_back(succ);
+          }
+
+          // Stores the blocks whose ops has been checked using
+          // `checkPropertiesOfOperation`.
+          SmallPtrSet<Block *, 16> done;
+          // Traverse the CFG until hitting `untilOp`.
+          while (!todoBlocks.empty()) {
+            Block *blk = todoBlocks.pop_back_val();
+            // Skip blocks already visited; insert().second is true only for a
+            // newly inserted block, so the guard must be negated (the previous
+            // condition skipped every block on its first visit).
+            if (!done.insert(blk).second)
+              continue;
+            for (Operation &op : *blk) {
+              if (&op == untilOp)
+                break;
+              if (checkPropertiesOfOperation(val, &op))
+                return true;
+              if (&op == blk->getTerminator())
+                for (Block *succ : blk->getSuccessors())
+                  todoBlocks.push_back(succ);
+            }
+          }
+          return false;
+        };
+    return recur(start, end);
+  }
+
+  /// Replace all occurrences of `destination` with `source` if the following
+  /// conditions are met.
+  /// 1) There should not exist any users of `destination` before the copy op.
+  /// 2) There should not be any write operations on `source` and `destination`
+  ///    after copy op.
+  void reuseCopySourceAsTarget(CopyOpInterface copyOp,
+                               llvm::SmallPtrSet<Operation *, 4> &opsToErase) {
+    if (opsToErase.count(copyOp))
+      return;
+
+    Value src = copyOp.getSource();
+    Value dest = copyOp.getTarget();
+
+    Operation *srcDeallocOp = getDeallocationOp(src);
+    Operation *destDeallocOp = getDeallocationOp(dest);
+    Operation *destDefOp = getAllocationOp(dest);
+    Operation *lastOpOfCurrentRegion = &src.getParentRegion()->back().back();
+    Operation *lastOpUsingSrc = lastOpOfCurrentRegion;
+
+    // If `srcDeallocOp` is not null, `lastOpUsingSrc` will be `srcDeallocOp`.
+    if (srcDeallocOp)
+      lastOpUsingSrc = srcDeallocOp;
+    Operation *firstOpUsingDest = &dest.getParentRegion()->front().front();
+
+    // If `destDefOp` is not null, `firstOpUsingDest` will be `destDefOp`.
+    if (destDefOp)
+      firstOpUsingDest = destDefOp;
+
+    // Capture all the cases when copy removal and replacing uses of `dest`
+    // with `src` is not possible.
+    if (hasInterveningOp(dest, firstOpUsingDest, copyOp, &doesOpUseVal) ||
+        doesOpUseVal(dest, firstOpUsingDest) ||
+        hasInterveningOp(dest, copyOp, lastOpUsingSrc,
+                         &doesOpHaveWriteEffect) ||
+        doesOpHaveWriteEffect(dest, lastOpUsingSrc) ||
+        hasInterveningOp(src, copyOp, lastOpUsingSrc, &doesOpHaveWriteEffect) ||
+        doesOpHaveWriteEffect(src, lastOpUsingSrc))
+      return;
+
+    // Erase the `copyOp`, `destDefOp` and `destDeallocOp`. Also remove
+    // `srcDeallocOp` if any uses of `dest` are there after `srcDeallocOp`, as
+    // we are replacing all instances of `dest` with `src`, and doing so will
+    // lead to occurrences of `src` after `srcDeallocOp`, which is semantically
+    // incorrect.
+    opsToErase.insert(copyOp);
+    if (destDefOp)
+      opsToErase.insert(destDefOp);
+    if (srcDeallocOp &&
+        (hasInterveningOp(dest, srcDeallocOp, lastOpOfCurrentRegion,
+                          &doesOpUseVal) ||
+         doesOpUseVal(dest, lastOpOfCurrentRegion)))
+      opsToErase.insert(srcDeallocOp);
+    if (destDeallocOp)
+      opsToErase.insert(destDeallocOp);
+
+    // Replace all uses of `dest` with `src`. (The comment previously had the
+    // two values swapped relative to the code below.)
+    dest.replaceAllUsesWith(src);
+  }
+
+  /// Remove copy statements when there are no uses of `destination` after the
+  /// copy op.
+  void removeCopy(CopyOpInterface copyOp,
+                  llvm::SmallPtrSet<Operation *, 4> &opsToErase) {
+    if (opsToErase.count(copyOp))
+      return;
+
+    Value src = copyOp.getSource();
+    Value dest = copyOp.getTarget();
+    // NOTE(review): this walks to the last op of `src`'s region while
+    // reasoning about uses of `dest` — confirm `src` and `dest` always live
+    // in the same region here.
+    Operation *lastOpUsingDest = &src.getParentRegion()->back().back();
+    Operation *destDeallocOp = getDeallocationOp(dest);
+    if (destDeallocOp)
+      lastOpUsingDest = destDeallocOp;
+    if (!hasInterveningOp(dest, copyOp, lastOpUsingDest, &doesOpUseVal) &&
+        (!doesOpUseVal(dest, lastOpUsingDest) || destDeallocOp))
+      opsToErase.insert(copyOp);
+  }
+};
+
+void CopyRemovalPass::runOnFunction() {
+  /// Operations that need to be removed.
+  llvm::SmallPtrSet<Operation *, 4> opsToErase;
+  getFunction()->walk([&](CopyOpInterface copyOp) {
+    reuseCopySourceAsTarget(copyOp, opsToErase);
+    removeCopy(copyOp, opsToErase);
+  });
+  for (Operation *op : opsToErase) {
+    assert(op->use_empty() &&
+           "uses remaining for copy ops, memref allocation and deallocation "
+           "ops that should be ready to be erased");
+    op->erase();
+  }
+}
+
+} // end anonymous namespace
+
+std::unique_ptr<OperationPass<FuncOp>> mlir::createCopyRemovalPass() {
+  return std::make_unique<CopyRemovalPass>();
+}
diff --git a/mlir/test/Transforms/copy-removal.mlir b/mlir/test/Transforms/copy-removal.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/copy-removal.mlir
@@ -0,0 +1,303 @@
+// RUN: mlir-opt -copy-removal -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @check_copy_removal_test1
+func @check_copy_removal_test1() -> memref<5xf32> {
+  %dest = memref.alloc() : memref<5xf32>
+  %src = memref.alloc() : memref<5xf32>
+  memref.copy %src, %dest : memref<5xf32> to memref<5xf32>
+  memref.dealloc %src : memref<5xf32>
+  return %dest : memref<5xf32>
+}
+// CHECK-NEXT: %[[SRC:.*]] = memref.alloc()
+// CHECK-NEXT: return %[[SRC]]
+
+// -----
+
+// CHECK-LABEL: func @check_copy_removal_test2
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32>, %[[ARG1:.*]]: memref<4xf32>, %[[ARG2:.*]]: memref<4xf32>)
+func @check_copy_removal_test2(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xf32>) {
+  %0 = memref.alloc() : memref<4xf32>
+  affine.for %arg3 = 0 to 4 {
+    %5 = affine.load %arg0[%arg3] : memref<4xf32>
+    %6 = affine.load %arg1[%arg3] : memref<4xf32>
+    %7 = arith.cmpf "ogt", %5, %6 : f32
+    %8 = select %7, %5, %6 : f32
+    affine.store %8, %0[%arg3] : memref<4xf32>
+  }
+  memref.copy %0, %arg2 : memref<4xf32> to memref<4xf32>
+  memref.dealloc %0 : memref<4xf32>
+  return
+}
+// CHECK-NEXT: %[[SRC:.*]] = memref.alloc() : memref<4xf32>
+// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 4 {
+// CHECK-NEXT: %[[TMP1:.*]] = affine.load %[[ARG0]][%[[ARG3]]] : memref<4xf32>
+//
CHECK-NEXT: %[[TMP2:.*]] = affine.load %[[ARG1]][%[[ARG3]]] : memref<4xf32> +// CHECK-NEXT: %[[TMP3:.*]] = arith.cmpf ogt, %[[TMP1]], %[[TMP2]] : f32 +// CHECK-NEXT: %[[TMP4:.*]] = select %[[TMP3]], %[[TMP1]], %[[TMP2]] : f32 +// CHECK-NEXT: affine.store %[[TMP4]], %[[SRC]][%[[ARG3]]] : memref<4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: memref.dealloc %[[SRC]] : memref<4xf32> +// CHECK-NEXT: return + +// ----- + +// CHECK: func @check_copy_removal_test3 +// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: index) -> memref<5xf32> +func @check_copy_removal_test3(%arg0 : f32, %index : index) -> memref<5xf32> { + %src = memref.alloc() : memref<5xf32> + %dest = memref.alloc() : memref<5xf32> + affine.store %arg0, %src[%index] : memref<5xf32> + memref.copy %src, %dest : memref<5xf32> to memref<5xf32> + memref.dealloc %src : memref<5xf32> + return %dest : memref<5xf32> +} +// CHECK-NEXT: %[[SRC:.*]] = memref.alloc() +// CHECK-NEXT: affine.store %[[ARG0]], %[[SRC]][%[[ARG1]]] +// CHECK-NEXT: return %[[SRC]] + +// ----- + +// CHECK-LABEL: func @check_copy_removal_test4 +// CHECK-SAME: (%[[ARG0:.*]]: i1) +func @check_copy_removal_test4(%cond : i1) -> memref<5xf32> { + %0 = memref.alloc() : memref<5xf32> + cond_br %cond, ^bb1, ^bb2 +^bb1: +br ^bb3(%0 : memref<5xf32>) +^bb2: + %temp = memref.alloc() : memref<5xf32> + memref.copy %0, %temp : memref<5xf32> to memref<5xf32> + memref.dealloc %0 : memref<5xf32> + br ^bb3(%temp : memref<5xf32>) +^bb3(%res : memref<5xf32>): + return %res : memref<5xf32> +} +// CHECK-NEXT: %[[ZERO:.*]] = memref.alloc() : memref<5xf32> +// CHECK-NEXT: cond_br %[[ARG0]], ^bb1, ^bb2 +// CHECK-NEXT: ^bb1: // pred: ^bb0 +// CHECK-NEXT: br ^bb3(%[[ZERO]] : memref<5xf32>) +// CHECK-NEXT: ^bb2: // pred: ^bb0 +// CHECK-NEXT: br ^bb3(%[[ZERO]] : memref<5xf32>) +// CHECK-NEXT: ^bb3(%[[ONE:.*]]: memref<5xf32>): // 2 preds: ^bb1, ^bb2 +// CHECK-NEXT: return %[[ONE]] : memref<5xf32> + +// ----- + +// CHECK-LABEL: func @check_copy_removal_test5 +// CHECK-SAME: 
(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[C0_0:.*]]: index, %[[C0_1:.*]]: index)
+func @check_copy_removal_test5(%arg0: index, %arg1: index, %c0_0: index, %c0_1: index) -> memref<?x?xf32> {
+  %0 = arith.cmpi "eq", %arg0, %arg1 : index
+  %1 = memref.alloc(%arg0, %arg0) : memref<?x?xf32>
+  %2 = scf.if %0 -> (memref<?x?xf32>) {
+    %3 = scf.if %0 -> (memref<?x?xf32>) {
+      %7 = memref.dim %1, %c0_0 : memref<?x?xf32>
+      %8 = memref.dim %1, %c0_1 : memref<?x?xf32>
+      %9 = memref.alloc(%7, %8) : memref<?x?xf32>
+      memref.copy %1, %9 : memref<?x?xf32> to memref<?x?xf32>
+      scf.yield %9 : memref<?x?xf32>
+    } else {
+      %7 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
+      %8 = memref.dim %7, %c0_0 : memref<?x?xf32>
+      %9 = memref.dim %7, %c0_1 : memref<?x?xf32>
+      %10 = memref.alloc(%8, %9) : memref<?x?xf32>
+      memref.copy %7, %10 : memref<?x?xf32> to memref<?x?xf32>
+      memref.dealloc %7 : memref<?x?xf32>
+      scf.yield %10 : memref<?x?xf32>
+    }
+    %4 = memref.dim %3, %c0_0 : memref<?x?xf32>
+    %5 = memref.dim %3, %c0_1 : memref<?x?xf32>
+    %6 = memref.alloc(%4, %5) : memref<?x?xf32>
+    memref.copy %3, %6 : memref<?x?xf32> to memref<?x?xf32>
+    memref.dealloc %3 : memref<?x?xf32>
+    scf.yield %6 : memref<?x?xf32>
+  } else {
+    %3 = memref.alloc(%arg1, %arg1) : memref<?x?xf32>
+    %4 = memref.dim %3, %c0_0 : memref<?x?xf32>
+    %5 = memref.dim %3, %c0_1 : memref<?x?xf32>
+    %6 = memref.alloc(%4, %5) : memref<?x?xf32>
+    memref.copy %3, %6 : memref<?x?xf32> to memref<?x?xf32>
+    memref.dealloc %3 : memref<?x?xf32>
+    scf.yield %6 : memref<?x?xf32>
+  }
+  memref.dealloc %1 : memref<?x?xf32>
+  return %2 : memref<?x?xf32>
+}
+// CHECK-NEXT: %[[TMP0:.*]] = arith.cmpi eq, %[[ARG0]], %[[ARG1]] : index
+// CHECK-NEXT: %[[TMP1:.*]] = memref.alloc(%[[ARG0]], %[[ARG0]]) : memref<?x?xf32>
+// CHECK-NEXT: %[[TMP2:.*]] = scf.if %[[TMP0]] -> (memref<?x?xf32>) {
+// CHECK-NEXT: %[[TMP3:.*]] = scf.if %[[TMP0]] -> (memref<?x?xf32>) {
+// CHECK-NEXT: %[[TMP6:.*]] = memref.dim %[[TMP1]], %[[C0_0]] : memref<?x?xf32>
+// CHECK-NEXT: %[[TMP7:.*]] = memref.dim %[[TMP1]], %[[C0_1]] : memref<?x?xf32>
+// CHECK-NEXT: scf.yield %[[TMP1]] : memref<?x?xf32>
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[TMP6:.*]] = memref.alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
+// CHECK-NEXT: %[[TMP7:.*]] = memref.dim %[[TMP6]], %[[C0_0]] : memref<?x?xf32>
+// CHECK-NEXT: %[[TMP8:.*]] = memref.dim %[[TMP6]], %[[C0_1]] : memref<?x?xf32>
+// CHECK-NEXT: scf.yield %[[TMP6]] : memref<?x?xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[TMP4:.*]] = memref.dim %[[TMP3]], %[[C0_0]] : memref<?x?xf32>
+// CHECK-NEXT: %[[TMP5:.*]] = memref.dim %[[TMP3]], %[[C0_1]] : memref<?x?xf32>
+// CHECK-NEXT: scf.yield %[[TMP3]] : memref<?x?xf32>
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[TMP3:.*]] = memref.alloc(%[[ARG1]], %[[ARG1]]) : memref<?x?xf32>
+// CHECK-NEXT: %[[TMP4:.*]] = memref.dim %[[TMP3]], %[[C0_0]] : memref<?x?xf32>
+// CHECK-NEXT: %[[TMP5:.*]] = memref.dim %[[TMP3]], %[[C0_1]] : memref<?x?xf32>
+// CHECK-NEXT: scf.yield %[[TMP3]] : memref<?x?xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: memref.dealloc %[[TMP1]] : memref<?x?xf32>
+// CHECK-NEXT: return %[[TMP2]] : memref<?x?xf32>
+
+// -----
+
+// CHECK-LABEL: func @check_copy_removal_test6
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: memref<2xf32>, %[[ARG4:.*]]: memref<2xf32>)
+func @check_copy_removal_test6(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<2xf32>, %arg4: memref<2xf32>) {
+  %0 = memref.alloc() : memref<2xf32>
+  memref.dealloc %0 : memref<2xf32>
+  %1 = memref.alloc() : memref<2xf32>
+  memref.copy %arg3, %1 : memref<2xf32> to memref<2xf32>
+  %2 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %1) -> (memref<2xf32>) {
+    %3 = arith.cmpi "eq", %arg5, %arg1 : index
+    memref.dealloc %arg6 : memref<2xf32>
+    %4 = memref.alloc() : memref<2xf32>
+    %5 = memref.alloc() : memref<2xf32>
+    memref.copy %4, %5 : memref<2xf32> to memref<2xf32>
+    memref.dealloc %4 : memref<2xf32>
+    %6 = memref.alloc() : memref<2xf32>
+    memref.copy %5, %6 : memref<2xf32> to memref<2xf32>
+    scf.yield %6 : memref<2xf32>
+  }
+  memref.copy %2, %arg4 : memref<2xf32> to memref<2xf32>
+  memref.dealloc %2 : memref<2xf32>
+  return
+}
+// CHECK-NEXT: %[[TMP0:.*]] = memref.alloc()
+// CHECK-NEXT: memref.dealloc %[[TMP0]]
+// CHECK-NEXT: %[[TMP1:.*]] = scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]] iter_args(%[[ARG6:.*]] = %[[ARG3]])
+// CHECK-NEXT: %[[TMP2:.*]] = arith.cmpi eq, %[[IV]], %[[ARG1]]
+// CHECK-NEXT: memref.dealloc %[[ARG6]] +// CHECK-NEXT: %[[TMP3:.*]] = memref.alloc() +// CHECK-NEXT: scf.yield %[[TMP3]] +// CHECK-NEXT: } +// CHECK-NEXT: memref.dealloc %[[TMP1]] +// CHECK-NEXT: return + +// ----- + +// CHECK-LABEL: func @check_copy_removal_test7 +// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: index) +func @check_copy_removal_test7(%arg0 : f32, %arg1 : index) -> memref<5xf32> { + %src = memref.alloc() : memref<5xf32> + %dest = memref.alloc() : memref<5xf32> + memref.copy %src, %dest : memref<5xf32> to memref<5xf32> + %temp = affine.load %src[%arg1] : memref<5xf32> + memref.dealloc %src : memref<5xf32> + return %dest : memref<5xf32> +} +// CHECK-NEXT: %[[SRC:.*]] = memref.alloc() : memref<5xf32> +// CHECK-NEXT: %[[TMP:.*]] = affine.load %[[SRC]][%[[ARG1]]] : memref<5xf32> +// CHECK-NEXT: return %[[SRC]] : memref<5xf32> + +// ----- + +// CHECK-LABEL: func @check_copy_removal_test9 +// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: index) +func @check_copy_removal_test9(%arg0 : f32, %arg1 : index) { + %src = memref.alloc() : memref<5xf32> + %dest = memref.alloc() : memref<5xf32> + %temp = affine.load %dest[%arg1] : memref<5xf32> + memref.copy %src, %dest : memref<5xf32> to memref<5xf32> + memref.dealloc %src : memref<5xf32> + memref.dealloc %dest : memref<5xf32> + return +} +// CHECK-NEXT: %[[SRC:.*]] = memref.alloc() : memref<5xf32> +// CHECK-NEXT: %[[DEST:.*]] = memref.alloc() : memref<5xf32> +// CHECK-NEXT: %[[TMP:.*]] = affine.load %[[DEST]][%[[ARG1]]] : memref<5xf32> +// CHECK-NEXT: memref.dealloc %[[SRC]] : memref<5xf32> +// CHECK-NEXT: memref.dealloc %[[DEST]] : memref<5xf32> +// CHECK-NEXT: return + +// ----- + +// CHECK-LABEL: func @check_copy_removal_test10 +// CHECK-SAME: (%[[COND:.*]]: i1, %[[TEMP:.*]]: memref<5xf32>, %[[VAL:.*]]: f32, %[[INDEX:.*]]: index) +func @check_copy_removal_test10(%cond : i1, %temp : memref<5xf32>, %val : f32, %index : index) -> memref<5xf32> { +%0 = memref.alloc() : memref<5xf32> +cond_br %cond, ^bb1, ^bb2 
+^bb1: +br ^bb3(%0 : memref<5xf32>) +^bb2: +memref.copy %0, %temp : memref<5xf32> to memref<5xf32> +memref.dealloc %0 : memref<5xf32> +br ^bb3(%temp : memref<5xf32>) +^bb3(%res : memref<5xf32>): +memref.store %val, %res[%index] : memref<5xf32> +return %res : memref<5xf32> +} +// CHECK-NEXT: %[[SRC:.*]] = memref.alloc() : memref<5xf32> +// CHECK-NEXT: cond_br %[[COND]], ^bb1, ^bb2 +// CHECK-NEXT: ^bb1: // pred: ^bb0 +// CHECK-NEXT: br ^bb3(%[[SRC]] : memref<5xf32>) +// CHECK-NEXT: ^bb2: // pred: ^bb0 +// CHECK-NEXT: br ^bb3(%[[SRC]] : memref<5xf32>) +// CHECK-NEXT: ^bb3(%[[ARG:.*]]: memref<5xf32>): // 2 preds: ^bb1, ^bb2 +// CHECK-NEXT: memref.store %[[VAL]], %[[ARG]][%[[INDEX]]] : memref<5xf32> +// CHECK-NEXT: return %[[ARG]] : memref<5xf32> + +// ----- + +// CHECK-LABEL: func @do_not_remove_copy_test_1 +// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: index) +func @do_not_remove_copy_test_1(%arg0 : f32, %arg2 : index) -> memref<5xf32> { + %src = memref.alloc() : memref<5xf32> + %dest = memref.alloc() : memref<5xf32> + memref.copy %src, %dest : memref<5xf32> to memref<5xf32> + affine.store %arg0, %src[%arg2] : memref<5xf32> + memref.dealloc %src : memref<5xf32> + return %dest : memref<5xf32> +} +// CHECK-NEXT: %[[SRC:.*]] = memref.alloc() : memref<5xf32> +// CHECK-NEXT: %[[DEST:.*]] = memref.alloc() : memref<5xf32> +// CHECK-NEXT: memref.copy %[[SRC]], %[[DEST]] : memref<5xf32> to memref<5xf32> +// CHECK-NEXT: affine.store %[[ARG0]], %[[SRC]][%[[ARG1]]] : memref<5xf32> +// CHECK-NEXT: memref.dealloc %[[SRC]] : memref<5xf32> +// CHECK-NEXT: return %[[DEST]] : memref<5xf32> + +// ----- + +// CHECK-LABEL: func @do_not_remove_copy_test_2 +// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: index) +func @do_not_remove_copy_test_2(%arg0 : f32, %arg1 : index) -> memref<5xf32> { + %src = memref.alloc() : memref<5xf32> + %dest = memref.alloc() : memref<5xf32> + affine.store %arg0, %dest[%arg1] : memref<5xf32> + memref.copy %src, %dest : memref<5xf32> to memref<5xf32> + 
memref.dealloc %src : memref<5xf32> + return %dest : memref<5xf32> +} +// CHECK-NEXT: %[[SRC:.*]] = memref.alloc() : memref<5xf32> +// CHECK-NEXT: %[[DEST:.*]] = memref.alloc() : memref<5xf32> +// CHECK-NEXT: affine.store %[[ARG0]], %[[DEST]][%[[ARG1]]] : memref<5xf32> +// CHECK-NEXT: memref.copy %[[SRC]], %[[DEST]] : memref<5xf32> to memref<5xf32> +// CHECK-NEXT: memref.dealloc %[[SRC]] : memref<5xf32> +// CHECK-NEXT: return %[[DEST]] : memref<5xf32> + +// ----- + +// CHECK-LABEL: func @do_not_remove_copy_test_3 +// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: index) +func @do_not_remove_copy_test_3(%arg0 : f32, %arg1 : index) -> memref<5xf32> { + %src = memref.alloc() : memref<5xf32> + %dest = memref.alloc() : memref<5xf32> + %temp = affine.load %dest[%arg1] : memref<5xf32> + memref.copy %src, %dest : memref<5xf32> to memref<5xf32> + memref.dealloc %src : memref<5xf32> + return %dest : memref<5xf32> +} +// CHECK-NEXT: %[[SRC:.*]] = memref.alloc() : memref<5xf32> +// CHECK-NEXT: %[[DEST:.*]] = memref.alloc() : memref<5xf32> +// CHECK-NEXT: %[[TMP:.*]] = affine.load %[[DEST]][%[[ARG1]]] : memref<5xf32> +// CHECK-NEXT: memref.copy %[[SRC]], %[[DEST]] : memref<5xf32> to memref<5xf32> +// CHECK-NEXT: memref.dealloc %[[SRC]] : memref<5xf32> +// CHECK-NEXT: return %[[DEST]] : memref<5xf32>