diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -498,32 +498,51 @@
 
     Value source = cloneOp.input();
 
-    // Removes the clone operation and the corresponding dealloc and alloc
-    // operation (if any).
-    auto tryRemoveClone = [&](Operation *sourceOp, Operation *dealloc,
-                              Operation *alloc) {
-      if (!sourceOp || !dealloc || !alloc ||
-          alloc->getBlock() != dealloc->getBlock())
-        return false;
-      rewriter.replaceOp(cloneOp, source);
-      rewriter.eraseOp(dealloc);
-      return true;
-    };
-
-    // Removes unnecessary clones that are derived from the result of the clone
-    // op.
-    Operation *deallocOp = findDealloc(cloneOp.output());
-    Operation *sourceOp = source.getDefiningOp();
-    if (tryRemoveClone(sourceOp, deallocOp, sourceOp))
-      return success();
+    // This only finds dealloc operations for the immediate value. It should
+    // also consider aliases. That would also make the safety check below
+    // redundant.
+    Operation *cloneDeallocOp = findDealloc(cloneOp.output());
+    Operation *sourceDeallocOp = findDealloc(source);
+
+    // If both are deallocated in the same block, their in-block lifetimes
+    // might not fully overlap, so we cannot decide which one to drop.
+    if (cloneDeallocOp && sourceDeallocOp &&
+        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
+      return failure();
 
-    // Removes unnecessary clones that are derived from the source of the clone
-    // op.
-    deallocOp = findDealloc(source);
-    if (tryRemoveClone(sourceOp, deallocOp, cloneOp))
-      return success();
+    Block *currentBlock = cloneOp->getBlock();
+    Operation *redundantDealloc = nullptr;
+    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
+      redundantDealloc = cloneDeallocOp;
+    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
+      redundantDealloc = sourceDeallocOp;
+    }
 
-    return failure();
+    if (!redundantDealloc)
+      return failure();
+
+    // Safety check that there are no other deallocations in between
+    // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
+    // of source before the uses of the clone. With alias information, we could
+    // restrict this to only fail if the dealloc's operand is an alias
+    // of the source.
+    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
+         pos = pos->getNextNode()) {
+      // Reaching the end of the block means redundantDealloc precedes
+      // cloneOp; bail instead of dereferencing a null operation.
+      if (!pos)
+        return failure();
+      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
+      if (!effectInterface)
+        continue;
+      if (effectInterface.hasEffect<MemoryEffects::Free>())
+        return failure();
+    }
+
+    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
+                                                source);
+    rewriter.eraseOp(redundantDealloc);
+    return success();
   }
 };
 
diff --git a/mlir/lib/Transforms/BufferDeallocation.cpp b/mlir/lib/Transforms/BufferDeallocation.cpp
--- a/mlir/lib/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Transforms/BufferDeallocation.cpp
@@ -151,8 +151,10 @@
     }
     // Recurse into all distinct regions and check for explicit control-flow
     // loops.
-    for (Region &region : op->getRegions())
-      recurse(region.front(), current);
+    for (Region &region : op->getRegions()) {
+      if (!region.empty())
+        recurse(region.front(), current);
+    }
   }
 
   /// Recurses into explicit control-flow structures that are given by
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -canonicalize --split-input-file -allow-unregistered-dialect | FileCheck %s
 
 // Test case: Basic folding of memref.tensor_load(memref.buffer_cast(t)) -> t
 // CHECK-LABEL: func @tensor_load_of_buffer_cast(
@@ -128,4 +128,78 @@
 //  CHECK-SAME:       [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
 //  CHECK-SAME:       : memref<?x?x?xf32> to memref<4x?xf32
 //       CHECK:   %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
-//       CHEKC:   return %[[RESULT]]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+// CHECK-LABEL: @clone_before_dealloc
+// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
+func @clone_before_dealloc(%arg0: memref<?xf32>) -> memref<?xf32> {
+  // CHECK-NEXT: return %[[ARG]]
+  %0 = memref.clone %arg0 : memref<?xf32> to memref<?xf32>
+  memref.dealloc %arg0 : memref<?xf32>
+  return %0 : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @clone_before_dealloc
+// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
+func @clone_before_dealloc(%arg0: memref<?xf32>) -> memref<?xf32> {
+  // CHECK-NEXT: "use"(%[[ARG]])
+  // CHECK-NEXT: return %[[ARG]]
+  %0 = memref.clone %arg0 : memref<?xf32> to memref<?xf32>
+  "use"(%0) : (memref<?xf32>) -> ()
+  memref.dealloc %0 : memref<?xf32>
+  return %arg0 : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @clone_after_cast
+// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
+func @clone_after_cast(%arg0: memref<?xf32>) -> memref<32xf32> {
+  // CHECK-NEXT: memref.clone %[[ARG]] : memref<?xf32> to memref<32xf32>
+  // CHECK-NOT: memref.cast
+  %0 = memref.cast %arg0 : memref<?xf32> to memref<32xf32>
+  %1 = memref.clone %0 : memref<32xf32> to memref<32xf32>
+  return %1 : memref<32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @clone_and_cast
+// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
+func @clone_and_cast(%arg0: memref<?xf32>) -> memref<32xf32> {
+  // CHECK-NEXT: %[[RES:.*]] = memref.cast %[[ARG]] : memref<?xf32> to memref<32xf32>
+  %0 = memref.clone %arg0 : memref<?xf32> to memref<32xf32>
+  // CHECK-NEXT: return %[[RES]]
+  memref.dealloc %arg0 : memref<?xf32>
+  return %0 : memref<32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @alias_is_freed
+func @alias_is_freed(%arg0 : memref<?xf32>) {
+  // CHECK: memref.clone
+  // CHECK: memref.dealloc
+  // CHECK: memref.dealloc
+  %0 = memref.cast %arg0 : memref<?xf32> to memref<32xf32>
+  %1 = memref.clone %0 : memref<32xf32> to memref<32xf32>
+  memref.dealloc %arg0 : memref<?xf32>
+  "use"(%1) : (memref<32xf32>) -> ()
+  memref.dealloc %1 : memref<32xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @dealloc_before_clone
+func @dealloc_before_clone(%arg0: memref<?xf32>) -> memref<?xf32> {
+  // CHECK: memref.dealloc
+  // CHECK: memref.clone
+  memref.dealloc %arg0 : memref<?xf32>
+  %0 = memref.clone %arg0 : memref<?xf32> to memref<?xf32>
+  return %0 : memref<?xf32>
+}