diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -465,6 +465,29 @@
     Operation *cloneDeallocOp = *maybeCloneDeallocOp;
     Operation *sourceDeallocOp = *maybeSourceDeallocOp;
 
+    // Before checking the operation blocks, we need to make sure the right
+    // source op is being checked. Some ops like ReshapeOp are aliases to
+    // allocation ops and could prevent removal of unnecessary allocs.
+    if (!sourceDeallocOp) {
+      // Guard against a null predecessor (e.g. cloneOp being first in its
+      // block) before casting.
+      // NOTE(review): cast target inferred from the comment above -- confirm.
+      Operation *sourceOp = cloneOp->getPrevNode();
+      while (sourceOp && dyn_cast<memref::ReshapeOp>(sourceOp)) {
+        llvm::Optional<Operation *> tmpDeallocOp =
+            memref::findDealloc(sourceOp->getResult(0));
+        // Validate the optional computed for this op, not the one belonging
+        // to the original source, before dereferencing it.
+        if (!tmpDeallocOp.has_value())
+          return failure();
+        if (Operation *tmpOp = *tmpDeallocOp) {
+          sourceDeallocOp = tmpOp;
+          break;
+        }
+        sourceOp = sourceOp->getPrevNode();
+      }
+    }
+
     // If both are deallocated in the same block, their in-block lifetimes
     // might not fully overlap, so we cannot decide which one to drop.
     if (cloneDeallocOp && sourceDeallocOp &&
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -256,3 +256,17 @@
 // CHECK: %[[T0:.+]] = bufferization.alloc_tensor() : tensor<4x5x6xf32>
 // CHECK: %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
 // CHECK: return %[[T1]]
+
+// NOTE(review): memref types below reconstructed (?xindex / ?xf32) -- confirm.
+func.func @Atanh_platform_elem_type_output_type(%arg0: memref<*xf32>) -> memref<*xf32> {
+  %0 = memref.rank %arg0 : memref<*xf32>
+  %1 = memref.alloca(%0) : memref<?xindex>
+  %c1 = arith.constant 1 : index
+  %2 = memref.alloc(%c1) : memref<?xf32>
+  %3 = memref.reshape %2(%1) : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
+  %4 = bufferization.clone %3 : memref<*xf32> to memref<*xf32>
+  memref.dealloc %2 : memref<?xf32>
+  return %4 : memref<*xf32>
+}
+// CHECK-NOT: bufferization.clone
+// CHECK-NOT: memref.dealloc