diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -450,16 +450,23 @@
 
     Value source = cloneOp.getInput();
 
-    // This only finds dealloc operations for the immediate value. It should
-    // also consider aliases. That would also make the safety check below
-    // redundant.
+    // Look through view-like ops (e.g. memref.reshape) to the buffer they
+    // alias and search for the dealloc of that buffer instead. Otherwise a
+    // clone of a view is never paired with the dealloc of the underlying
+    // allocation and the unnecessary clone is kept. dyn_cast_or_null is
+    // required: the chain may end at a block argument (no defining op).
+    Value canonicalSource = source;
+    while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
+               canonicalSource.getDefiningOp()))
+      canonicalSource = iface.getViewSource();
+
     llvm::Optional<Operation *> maybeCloneDeallocOp =
         memref::findDealloc(cloneOp.getOutput());
     // Skip if either of them has > 1 deallocate operations.
     if (!maybeCloneDeallocOp.has_value())
       return failure();
     llvm::Optional<Operation *> maybeSourceDeallocOp =
-        memref::findDealloc(source);
+        memref::findDealloc(canonicalSource);
     if (!maybeSourceDeallocOp.has_value())
       return failure();
     Operation *cloneDeallocOp = *maybeCloneDeallocOp;
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -256,3 +256,37 @@
 // CHECK: %[[T0:.+]] = bufferization.alloc_tensor() : tensor<4x5x6xf32>
 // CHECK: %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
 // CHECK: return %[[T1]]
+
+// -----
+
+// The clone source is a view (memref.reshape) of an allocation that is
+// deallocated right after the clone. SimplifyClones must look through the view
+// to find that dealloc, so both the clone and the dealloc are removed.
+func.func @clone_of_reshaped_alloc(%arg0: memref<*xf32>) -> memref<*xf32> {
+  %0 = memref.rank %arg0 : memref<*xf32>
+  %1 = memref.alloca(%0) : memref<?xindex>
+  %c1 = arith.constant 1 : index
+  %2 = memref.alloc(%c1) : memref<?xf32>
+  %3 = memref.reshape %2(%1) : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
+  %4 = bufferization.clone %3 : memref<*xf32> to memref<*xf32>
+  memref.dealloc %2 : memref<?xf32>
+  return %4 : memref<*xf32>
+}
+// CHECK-LABEL: @clone_of_reshaped_alloc
+// CHECK-NOT: bufferization.clone
+// CHECK-NOT: memref.dealloc
+// CHECK: return {{.*}}
+
+// -----
+
+// The view chain ends at a block argument, which has no defining op. Nothing
+// can be simplified here, and the pattern must bail out instead of crashing.
+func.func @clone_of_reshaped_block_argument(%arg0: memref<?xf32>,
+    %arg1: memref<?xindex>) -> memref<*xf32> {
+  %0 = memref.reshape %arg0(%arg1) : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
+  %1 = bufferization.clone %0 : memref<*xf32> to memref<*xf32>
+  return %1 : memref<*xf32>
+}
+// CHECK-LABEL: @clone_of_reshaped_block_argument
+// CHECK: bufferization.clone
+// CHECK: return