diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -449,17 +449,22 @@
     }
 
     Value source = cloneOp.getInput();
+    // Look through view-like ops (e.g. memref.reshape) to the buffer they
+    // alias: the dealloc may be attached to that buffer rather than to the
+    // view, and missing it would block removal of an unnecessary clone.
+    // Only ViewLikeOpInterface chains are followed; other aliases are not.
+    Value canonicalSource = source;
+    while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
+               canonicalSource.getDefiningOp()))
+      canonicalSource = iface.getViewSource();
 
-    // This only finds dealloc operations for the immediate value. It should
-    // also consider aliases. That would also make the safety check below
-    // redundant.
     llvm::Optional<Operation *> maybeCloneDeallocOp =
         memref::findDealloc(cloneOp.getOutput());
     // Skip if either of them has > 1 deallocate operations.
     if (!maybeCloneDeallocOp.has_value())
       return failure();
     llvm::Optional<Operation *> maybeSourceDeallocOp =
-        memref::findDealloc(source);
+        memref::findDealloc(canonicalSource);
     if (!maybeSourceDeallocOp.has_value())
       return failure();
     Operation *cloneDeallocOp = *maybeCloneDeallocOp;
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -256,3 +256,18 @@
 // CHECK: %[[T0:.+]] = bufferization.alloc_tensor() : tensor<4x5x6xf32>
 // CHECK: %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
 // CHECK: return %[[T1]]
+
+// -----
+
+func.func @dealloc_canonicalize_clone_removal(%arg0: memref<?xindex>) -> memref<*xf32> {
+  %c1 = arith.constant 1 : index
+  %0 = memref.alloc(%c1) : memref<?xf32>
+  %1 = memref.reshape %0(%arg0) : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
+  %2 = bufferization.clone %1 : memref<*xf32> to memref<*xf32>
+  memref.dealloc %0 : memref<?xf32>
+  return %2 : memref<*xf32>
+}
+// CHECK-LABEL: @dealloc_canonicalize_clone_removal
+// CHECK-NOT: bufferization.clone
+// CHECK-NOT: memref.dealloc
+// CHECK: return {{.*}}