diff --git a/mlir/lib/Transforms/CopyRemoval.cpp b/mlir/lib/Transforms/CopyRemoval.cpp --- a/mlir/lib/Transforms/CopyRemoval.cpp +++ b/mlir/lib/Transforms/CopyRemoval.cpp @@ -30,16 +30,35 @@ reuseCopySourceAsTarget(copyOp); reuseCopyTargetAsSource(copyOp); }); + for (std::pair &pair : replaceList) + pair.first.replaceAllUsesWith(pair.second); for (Operation *op : eraseList) op->erase(); } private: /// List of operations that need to be removed. - DenseSet eraseList; + llvm::SmallPtrSet eraseList; + + /// List of values that need to be replaced with their counterparts. + llvm::SmallDenseSet, 4> replaceList; + + /// Returns the allocation operation for `value` in `block` if it exists. + /// nullptr otherwise. + Operation *getAllocationOpInBlock(Value value, Block *block) { + assert(block && "Block cannot be null"); + Operation *op = value.getDefiningOp(); + if (op && op->getBlock() == block) { + auto effects = dyn_cast(op); + if (effects && effects.hasEffect()) + return op; + } + return nullptr; + } /// Returns the deallocation operation for `value` in `block` if it exists. - Operation *getDeallocationInBlock(Value value, Block *block) { + /// nullptr otherwise. 
+ Operation *getDeallocationOpInBlock(Value value, Block *block) { assert(block && "Block cannot be null"); auto valueUsers = value.getUsers(); auto it = llvm::find_if(valueUsers, [&](Operation *op) { @@ -119,9 +138,10 @@ Value to = copyOp.getTarget(); Operation *copy = copyOp.getOperation(); + Block *copyBlock = copy->getBlock(); Operation *fromDefiningOp = from.getDefiningOp(); - Operation *fromFreeingOp = getDeallocationInBlock(from, copy->getBlock()); - Operation *toDefiningOp = to.getDefiningOp(); + Operation *fromFreeingOp = getDeallocationOpInBlock(from, copyBlock); + Operation *toDefiningOp = getAllocationOpInBlock(to, copyBlock); if (!fromDefiningOp || !fromFreeingOp || !toDefiningOp || !areOpsInTheSameBlock({fromFreeingOp, toDefiningOp, copy}) || hasUsersBetween(to, toDefiningOp, copy) || @@ -129,7 +149,7 @@ hasMemoryEffectOpBetween(copy, fromFreeingOp)) return; - to.replaceAllUsesWith(from); + replaceList.insert({to, from}); eraseList.insert(copy); eraseList.insert(toDefiningOp); eraseList.insert(fromFreeingOp); @@ -169,8 +189,9 @@ Value to = copyOp.getTarget(); Operation *copy = copyOp.getOperation(); - Operation *fromDefiningOp = from.getDefiningOp(); - Operation *fromFreeingOp = getDeallocationInBlock(from, copy->getBlock()); + Block *copyBlock = copy->getBlock(); + Operation *fromDefiningOp = getAllocationOpInBlock(from, copyBlock); + Operation *fromFreeingOp = getDeallocationOpInBlock(from, copyBlock); if (!fromDefiningOp || !fromFreeingOp || !areOpsInTheSameBlock({fromFreeingOp, fromDefiningOp, copy}) || hasUsersBetween(to, fromDefiningOp, copy) || @@ -178,7 +199,7 @@ hasMemoryEffectOpBetween(copy, fromFreeingOp)) return; - from.replaceAllUsesWith(to); + replaceList.insert({from, to}); eraseList.insert(copy); eraseList.insert(fromDefiningOp); eraseList.insert(fromFreeingOp); diff --git a/mlir/test/Transforms/copy-removal.mlir b/mlir/test/Transforms/copy-removal.mlir --- a/mlir/test/Transforms/copy-removal.mlir +++ 
b/mlir/test/Transforms/copy-removal.mlir @@ -283,3 +283,67 @@ dealloc %temp : memref<2xf32> return } + +// ----- + +// The only redundant copy is linalg.copy(%4, %5) + +// CHECK-LABEL: func @loop_alloc +func @loop_alloc(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<2xf32>, %arg4: memref<2xf32>) { + // CHECK: %{{.*}} = alloc() + %0 = alloc() : memref<2xf32> + dealloc %0 : memref<2xf32> + // CHECK: %{{.*}} = alloc() + %1 = alloc() : memref<2xf32> + // CHECK: linalg.copy + linalg.copy(%arg3, %1) : memref<2xf32>, memref<2xf32> + %2 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %1) -> (memref<2xf32>) { + %3 = cmpi "eq", %arg5, %arg1 : index + // CHECK: dealloc + dealloc %arg6 : memref<2xf32> + // CHECK: %[[PERCENT4:.*]] = alloc() + %4 = alloc() : memref<2xf32> + // CHECK-NOT: alloc + // CHECK-NOT: linalg.copy + // CHECK-NOT: dealloc + %5 = alloc() : memref<2xf32> + linalg.copy(%4, %5) : memref<2xf32>, memref<2xf32> + dealloc %4 : memref<2xf32> + // CHECK: %[[PERCENT6:.*]] = alloc() + %6 = alloc() : memref<2xf32> + // CHECK: linalg.copy(%[[PERCENT4]], %[[PERCENT6]]) + linalg.copy(%5, %6) : memref<2xf32>, memref<2xf32> + scf.yield %6 : memref<2xf32> + } + // CHECK: linalg.copy + linalg.copy(%2, %arg4) : memref<2xf32>, memref<2xf32> + dealloc %2 : memref<2xf32> + return +} + +// ----- + +// The linalg.copy operation can be removed in addition to alloc and dealloc +// operations. All uses of %0 are then replaced with %arg2. 
+ +// CHECK-LABEL: func @check_with_affine_dialect +func @check_with_affine_dialect(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xf32>) { + // CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32>, %[[ARG1:.*]]: memref<4xf32>, %[[RES:.*]]: memref<4xf32>) + // CHECK-NOT: alloc + %0 = alloc() : memref<4xf32> + affine.for %arg3 = 0 to 4 { + %5 = affine.load %arg0[%arg3] : memref<4xf32> + %6 = affine.load %arg1[%arg3] : memref<4xf32> + %7 = cmpf "ogt", %5, %6 : f32 + // CHECK: %[[SELECT_RES:.*]] = select + %8 = select %7, %5, %6 : f32 + // CHECK-NEXT: affine.store %[[SELECT_RES]], %[[RES]] + affine.store %8, %0[%arg3] : memref<4xf32> + } + // CHECK-NOT: linalg.copy + // CHECK-NOT: dealloc + "linalg.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> () + dealloc %0 : memref<4xf32> + //CHECK: return + return +}