diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -450,16 +450,21 @@
 
     Value source = cloneOp.getInput();
 
-    // This only finds dealloc operations for the immediate value. It should
-    // also consider aliases. That would also make the safety check below
-    // redundant.
+    // Look through view-like ops (e.g. memref.reshape) so that the dealloc of
+    // the underlying buffer is found; a view in between would otherwise
+    // prevent removal of unnecessary allocs. The view source may be a block
+    // argument with no defining op, hence dyn_cast_or_null.
+    Value canonicalSource = source;
+    while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
+               canonicalSource.getDefiningOp()))
+      canonicalSource = iface.getViewSource();
     llvm::Optional<Operation *> maybeCloneDeallocOp =
         memref::findDealloc(cloneOp.getOutput());
     // Skip if either of them has > 1 deallocate operations.
     if (!maybeCloneDeallocOp.has_value())
       return failure();
     llvm::Optional<Operation *> maybeSourceDeallocOp =
-        memref::findDealloc(source);
+        memref::findDealloc(canonicalSource);
     if (!maybeSourceDeallocOp.has_value())
       return failure();
     Operation *cloneDeallocOp = *maybeCloneDeallocOp;
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -256,3 +256,18 @@
 // CHECK: %[[T0:.+]] = bufferization.alloc_tensor() : tensor<4x5x6xf32>
 // CHECK: %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
 // CHECK: return %[[T1]]
+
+// -----
+
+func.func @dealloc_canonicalize_clone_removal(%arg0: memref<?xindex>) -> memref<*xf32> {
+  %c1 = arith.constant 1 : index
+  %0 = memref.alloc(%c1) : memref<?xf32>
+  %1 = memref.reshape %0(%arg0) : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
+  %2 = bufferization.clone %1 : memref<*xf32> to memref<*xf32>
+  memref.dealloc %0 : memref<?xf32>
+  return %2 : memref<*xf32>
+}
+// CHECK-LABEL: @dealloc_canonicalize_clone_removal
+// CHECK-NOT: bufferization.clone
+// CHECK-NOT: memref.dealloc
+// CHECK: return {{.*}}