diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp @@ -177,9 +177,37 @@ Operation *startOperation = BufferPlacementAllocs::getStartOperation( allocValue, placementBlock, liveness); - // Move the alloc in front of the start operation. + // Find the corresponding dealloc (if any). Operation *allocOperation = allocValue.getDefiningOp(); + Operation *deallocOperation = nullptr; + bool isSupported = [&]() { + for (OpResult allocResult : allocOperation->getResults()) { + for (OpOperand &allocUse : allocResult.getUses()) { + Operation *deallocCandidate = allocUse.getOwner(); + if (StateT::isMatchingDealloc(allocOperation, deallocCandidate)) { + // Multiple deallocs are not supported. + if (deallocOperation != nullptr && + deallocOperation != deallocCandidate) + return false; + // Deallocs in different blocks are not supported. + if (allocOperation->getBlock() != deallocCandidate->getBlock()) + return false; + deallocOperation = deallocCandidate; + } + } + } + return true; + }(); + + // Skip unsupported allocs. + if (!isSupported) + continue; + + // Move alloc and dealloc. allocOperation->moveBefore(startOperation); + if (deallocOperation) + deallocOperation->moveBefore( + allocOperation->getBlock()->getTerminator()); } } @@ -280,6 +308,13 @@ return llvm::isa(op); } + /// Returns true if the `deallocOp` is a deallocation of a result of + /// `allocOp`. One of the operands of `deallocOp` is guaranteed to be a result + /// of the `allocOp`. + static bool isMatchingDealloc(Operation *allocOp, Operation *deallocOp) { + return llvm::isa(deallocOp); + } + /// Sets the current placement block to the given block. void recordMoveToDominator(Block *block) { placementBlock = block; } @@ -317,6 +352,17 @@ return llvm::isa(op); } + /// Returns true if the `deallocOp` is a deallocation of a result of + /// `allocOp`. One of the operands of `deallocOp` is guaranteed to be a result + /// of the `allocOp`. + static bool isMatchingDealloc(Operation *allocOp, Operation *deallocOp) { + if (!llvm::isa(deallocOp)) + return false; + assert(llvm::isa(allocOp) && + "no deallocs expected for alloca allocations"); + return true; + } + /// Does not change the internal placement block, as we want to move /// operations out of loops only. void recordMoveToDominator(Block *block) {} diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-hoisting.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-hoisting.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/buffer-hoisting.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-hoisting.mlir @@ -307,8 +307,7 @@ // ----- -// Test Case: Invalid position of the DeallocOp. There is a user after -// deallocation. +// Test Case: // bb0 // / \ // bb1 bb2 <- Initial position of AllocOp @@ -328,7 +327,6 @@ ^bb2: %1 = memref.alloc() : memref<2xf32> test.buffer_based in(%arg0: memref<2xf32>) out(%1: memref<2xf32>) - memref.dealloc %1 : memref<2xf32> cf.br ^exit(%1 : memref<2xf32>) ^exit(%arg2: memref<2xf32>): test.copy(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-loop-hoisting.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-loop-hoisting.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/buffer-loop-hoisting.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-loop-hoisting.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -buffer-loop-hoisting -split-input-file %s | FileCheck %s +// RUN: mlir-opt -buffer-loop-hoisting -allow-unregistered-dialect -split-input-file %s | FileCheck %s // This file checks the behavior of BufferLoopHoisting pass for moving Alloc // operations in their correct positions. @@ -487,3 +487,51 @@ // CHECK: %[[ALLOCA0:.*]] = memref.alloca({{.*}}) // CHECK-NEXT: %[[ALLOCA1:.*]] = memref.alloca({{.*}}) // CHECK-NEXT: {{.*}} = scf.for + +// ----- + +// CHECK-LABEL: func @hoist_dealloc +func.func @hoist_dealloc(%lb: index, %ub: index, %step: index, %f: f32) { + scf.for %iv = %lb to %ub step %step { + %0 = memref.alloc() : memref<5xf32> + memref.store %f, %0[%iv] : memref<5xf32> + "print_memref"(%0) : (memref<5xf32>) -> () + memref.dealloc %0 : memref<5xf32> + } + return +} + +// CHECK: memref.alloc +// CHECK: scf.for {{.*}} { +// CHECK: memref.store +// CHECK: } +// CHECK: memref.dealloc +// CHECK: return + +// ----- + +// CHECK-LABEL: func @unsupported_hoist_dealloc +func.func @unsupported_hoist_dealloc(%lb: index, %ub: index, %step: index, %f: f32) { + scf.for %iv = %lb to %ub step %step { + %0 = memref.alloc() : memref<5xf32> + memref.store %f, %0[%iv] : memref<5xf32> + "print_memref"(%0) : (memref<5xf32>) -> () + %c = "some_condition"() : () -> (i1) + scf.if %c { + // Hoisting not supported because the dealloc appears in a block that is + // different from the alloc block. + memref.dealloc %0 : memref<5xf32> + } + } + return +} + +// CHECK: scf.for {{.*}} { +// CHECK: memref.alloc +// CHECK: memref.store +// CHECK: scf.if {{.*}} { +// CHECK: memref.dealloc +// CHECK: } +// CHECK: } +// CHECK-NOT: memref.dealloc +// CHECK: return