diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -453,20 +453,21 @@
-    // We aim to check dealloc op for aliases below, e.g. ReshapeOp
-    // which otherwise could prevent removal of unnecessary allocs.
-    Operation *sourceDeallocOp;
-    Value source = cloneOp.getInput();
-    Operation *source_op = cloneOp.getInput().getDefiningOp();
-    while (auto iface = dyn_cast<ViewLikeOpInterface>(source_op))
-      source_op = iface.getViewSource().getDefiningOp();
-    if (isa<memref::AllocOp>(source_op)) {
-      llvm::Optional<Operation *> tmpDeallocOp =
-          memref::findDealloc(source_op->getResult(0));
-      if (!tmpDeallocOp.has_value())
-        return failure();
-      Operation *tmpOp = *tmpDeallocOp;
-      if (tmpOp) {
-        sourceDeallocOp = tmpOp;
-      }
-    }
+    // Look through view-like ops (e.g. memref.reshape) to the value they are
+    // views of, so that the dealloc of the underlying buffer is found even
+    // when the clone is taken of a view of it. The walk stops at a block
+    // argument (it has no defining op) or at an op that is not view-like.
+    Value source = cloneOp.getInput();
+    Value canonicalSource = source;
+    while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
+               canonicalSource.getDefiningOp()))
+      canonicalSource = iface.getViewSource();
+
+    // Skip if the underlying buffer has > 1 deallocate operations.
+    llvm::Optional<Operation *> maybeCanonicalSourceDeallocOp =
+        memref::findDealloc(canonicalSource);
+    if (!maybeCanonicalSourceDeallocOp.has_value())
+      return failure();
+    // Null when the underlying buffer has no dealloc.
+    Operation *sourceDeallocOp = *maybeCanonicalSourceDeallocOp;
+
     llvm::Optional<Operation *> maybeCloneDeallocOp =
         memref::findDealloc(cloneOp.getOutput());
     // Skip if either of them has > 1 deallocate operations.
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -257,17 +257,19 @@
 // CHECK: %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
 // CHECK: return %[[T1]]
 
-func.func @Atanh_platform_elem_type_output_type(%arg0: memref<*xf32>) -> memref<*xf32> {
-  %0 = memref.rank %arg0 : memref<*xf32>
-  %1 = memref.alloca(%0) : memref<?xindex>
+// -----
+
+// A clone of a memref.reshape view of a deallocated alloc is removed together
+// with that dealloc.
+func.func @dealloc_canonicalize_clone_removal(%arg0: memref<?xindex>) -> memref<*xf32> {
   %c1 = arith.constant 1 : index
-  %2 = memref.alloc(%c1) : memref<?xf32>
-  %3 = memref.reshape %2(%1) : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
-  %4 = bufferization.clone %3 : memref<*xf32> to memref<*xf32>
-  memref.dealloc %2 : memref<?xf32>
-  return %4 : memref<*xf32>
+  %0 = memref.alloc(%c1) : memref<?xf32>
+  %1 = memref.reshape %0(%arg0) : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
+  %2 = bufferization.clone %1 : memref<*xf32> to memref<*xf32>
+  memref.dealloc %0 : memref<?xf32>
+  return %2 : memref<*xf32>
 }
-// CHECK-LABEL: @Atanh_platform_elem_type_output_type
+// CHECK-LABEL: @dealloc_canonicalize_clone_removal
 // CHECK-NOT: bufferization.clone
 // CHECK-NOT: memref.dealloc
 // CHECK: return {{.*}}