diff --git a/mlir/lib/Transforms/BufferOptimizations.cpp b/mlir/lib/Transforms/BufferOptimizations.cpp
--- a/mlir/lib/Transforms/BufferOptimizations.cpp
+++ b/mlir/lib/Transforms/BufferOptimizations.cpp
@@ -125,12 +125,19 @@
 public:
   BufferAllocationHoisting(Operation *op)
       : BufferPlacementTransformationBase(op), dominators(op),
-        postDominators(op) {}
+        postDominators(op), scopeOp(op) {}
 
   /// Moves allocations upwards.
   void hoist() {
-    for (BufferPlacementAllocs::AllocEntry &entry : allocs) {
-      Value allocValue = std::get<0>(entry);
+    // Hoist the allocations tracked in `allocs` as well as every
+    // `memref.alloca` nested under the scope operation.
+    SmallVector<Value> allocsAndAllocas;
+    for (BufferPlacementAllocs::AllocEntry &entry : allocs)
+      allocsAndAllocas.push_back(std::get<0>(entry));
+    scopeOp->walk(
+        [&](memref::AllocaOp op) { allocsAndAllocas.push_back(op.memref()); });
+
+    for (Value allocValue : allocsAndAllocas) {
       Operation *definingOp = allocValue.getDefiningOp();
       assert(definingOp && "No defining op");
       auto operands = definingOp->getOperands();
@@ -222,6 +229,10 @@
 
   /// The map storing the final placement blocks of a given alloc value.
   llvm::DenseMap<Value, Block *> placementBlocks;
+
+  /// The operation this transformation works on. It is walked to additionally
+  /// gather `memref.alloca` operations, which are hoisted alongside `allocs`.
+  Operation *scopeOp;
 };
 
 /// A state implementation compatible with the `BufferAllocationHoisting` class
diff --git a/mlir/test/Transforms/buffer-hoisting.mlir b/mlir/test/Transforms/buffer-hoisting.mlir
--- a/mlir/test/Transforms/buffer-hoisting.mlir
+++ b/mlir/test/Transforms/buffer-hoisting.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -buffer-hoisting -split-input-file %s | FileCheck %s
 
 // This file checks the behaviour of BufferHoisting pass for moving Alloc
-// operations to their correct positions.
+// and Alloca operations to their correct positions.
 
 // Test Case:
 // bb0
@@ -552,7 +552,7 @@
 
 // -----
 
-// Test Case: Alloca operations shouldn't be moved.
+// Test Case: Alloca operations should also be moved.
 
 // CHECK-LABEL: func @condBranchAlloca
 func @condBranchAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
@@ -568,10 +568,10 @@
   return
 }
 
+// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca()
 // CHECK-NEXT: cond_br
 //      CHECK: ^bb2
 //      CHECK: ^bb2
-// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca()
 // CHECK-NEXT: test.buffer_based
 
 // -----
diff --git a/mlir/test/Transforms/buffer-loop-hoisting.mlir b/mlir/test/Transforms/buffer-loop-hoisting.mlir
--- a/mlir/test/Transforms/buffer-loop-hoisting.mlir
+++ b/mlir/test/Transforms/buffer-loop-hoisting.mlir
@@ -458,3 +458,32 @@
 // CHECK-NEXT: {{.*}} = scf.for
 // CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc({{.*}})
 // CHECK-NEXT: {{.*}} = scf.for
+
+// -----
+
+// Test with allocas to ensure that they are also hoisted out of loops.
+
+// CHECK-LABEL: func @hoist_alloca
+func @hoist_alloca(
+  %lb: index,
+  %ub: index,
+  %step: index,
+  %buf: memref<2xf32>,
+  %res: memref<2xf32>) {
+  %0 = memref.alloca() : memref<2xf32>
+  %1 = scf.for %i = %lb to %ub step %step
+    iter_args(%iterBuf = %buf) -> memref<2xf32> {
+    %2 = scf.for %i2 = %lb to %ub step %step
+      iter_args(%iterBuf2 = %iterBuf) -> memref<2xf32> {
+      %3 = memref.alloca() : memref<2xf32>
+      scf.yield %0 : memref<2xf32>
+    }
+    scf.yield %0 : memref<2xf32>
+  }
+  test.copy(%1, %res) : (memref<2xf32>, memref<2xf32>)
+  return
+}
+
+// CHECK:      %[[ALLOCA0:.*]] = memref.alloca({{.*}})
+// CHECK-NEXT: %[[ALLOCA1:.*]] = memref.alloca({{.*}})
+// CHECK-NEXT: {{.*}} = scf.for