diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -495,8 +495,9 @@ LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to, const BufferizationOptions &options); -/// Finalize all buffer allocations, i.e., create alloc ops as specified in the -/// bufferization options and deallocate all buffers. +/// Finalize all buffer allocations. +/// * Hoist buffer allocations as much as possible. +/// * Create alloc/dealloc ops as specified by the bufferization options. LogicalResult finalizeBuffers(Operation *op, const BufferizationOptions &options); } // namespace bufferization diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -387,40 +387,9 @@ return success(); } -/// Move the insertion point of the given builder to the beginning of a -/// surrounding block as much as possible, while not crossing any allocation -/// hoisting barriers. -static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) { - Operation *op = b.getInsertionBlock()->getParentOp(); - while (op) { - if (auto bufferizableOp = dyn_cast(op)) - if (bufferizableOp.isAllocationHoistingBarrier()) - break; - op = op->getParentOp(); - } - - if (!op) { - // No allocation hoisting barrier found. Hoist to FuncOp. - op = b.getInsertionBlock()->getParentOp(); - if (!isa(op)) - op = op->getParentOfType(); - assert(op && "could not find enclosing FuncOp"); - } - - // TODO: Handle cases where allocation hoisting barrier has more than one - // region or block. - assert(op->getNumRegions() == 1 && - "allocation hoisting barriers with >1 regions not supported"); - assert(op->getRegion(0).getBlocks().size() == 1 && - "allocation hoisting barriers with >1 blocks not supported"); - b.setInsertionPointToStart(&(op->getRegion(0).front())); -} - /// Compute the type of the `memref` to use for allocating the buffer for /// `shapedValue`. Also returns (by reference in `dynShape`), the value for the -/// dynamic dimensions in the returned `memref` type. The function may also set -/// the insertion point to an earlier location, where the allocation should -/// happen ("allocation hoisting"). +/// dynamic dimensions in the returned `memref` type. static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc, Value shapedValue, SmallVectorImpl &dynShape) { @@ -453,15 +422,6 @@ } } - // If the buffer is statically shaped, try to hoist it to the first enclosing - // parallel region. - // TODO: also hoist in the dynamic case. For now this relies on subsequent - // calls to LICM and buffer hoisting which will most likely not succeed. - // TODO: when packing, allocate a static bounding box which will enable more - // hoisting. - if (dynShape.empty()) - moveInsertionPointToAllocationHoistingBarrier(b); - return allocMemRefType; } @@ -481,7 +441,6 @@ assert(shapedValue.getType().isa()); MemRefType memRefType = shapedValue.getType().dyn_cast(); SmallVector dynShape; - // Note: getAllocationTypeAndShape also sets the insertion point. MemRefType allocMemRefType = getAllocationTypeAndShape(b, loc, shapedValue, dynShape); Value alloc = createBufferAllocation(b, loc, allocMemRefType, dynShape); @@ -511,9 +470,8 @@ return success(); } -LogicalResult -bufferization::finalizeBuffers(Operation *op, - const BufferizationOptions &options) { +static LogicalResult +createAllocDeallocOps(Operation *op, const BufferizationOptions &options) { IRRewriter rewriter(op->getContext()); // Bufferization creates memref.alloca ops. After bufferization, these must be @@ -546,6 +504,63 @@ return success(!status.wasInterrupted()); } +/// Try to hoist all new buffer allocations until the next hoisting barrier. +// TODO: Consolidate this function with the existing buffer hoisting pass. +static LogicalResult +hoistBufferAllocations(Operation *op, const BufferizationOptions &options) { + // Gather all buffer allocations that were created by the bufferization. + SmallVector allocaOps; + op->walk([&](memref::AllocaOp allocaOp) { + if (allocaOp->hasAttr(kBufferAllocationAttr)) + allocaOps.push_back(allocaOp); + }); + + for (Operation *allocaOp : allocaOps) { + // TODO: Hoisting of allocs with dynamic shape not implemented. + if (!allocaOp->getOpOperands().empty()) + continue; + + Operation *op = allocaOp->getBlock()->getParentOp(); + while (op) { + if (auto bufferizableOp = dyn_cast(op)) + if (bufferizableOp.isAllocationHoistingBarrier()) + break; + op = op->getParentOp(); + } + + if (!op) { + // No allocation hoisting barrier found. Hoist to FuncOp. + op = allocaOp->getBlock()->getParentOp(); + if (!isa(op)) + op = op->getParentOfType(); + assert(op && "could not find enclosing FuncOp"); + } + + // TODO: Handle cases where allocation hoisting barrier has more than one + // region or block. + assert(op->getNumRegions() == 1 && + "allocation hoisting barriers with >1 regions not supported"); + assert(op->getRegion(0).getBlocks().size() == 1 && + "allocation hoisting barriers with >1 blocks not supported"); + Block *insertionBlock = &(op->getRegion(0).front()); + // Move to the beginning of the block. + allocaOp->moveBefore(&insertionBlock->front()); + } + + return success(); +} + +LogicalResult +bufferization::finalizeBuffers(Operation *op, + const BufferizationOptions &options) { + if (failed(hoistBufferAllocations(op, options))) + return failure(); + if (failed(createAllocDeallocOps(op, options))) + return failure(); + + return success(); +} + //===----------------------------------------------------------------------===// // Bufferization-specific BlockAndValueMapping support with debugging. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir --- a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -71,8 +71,8 @@ #map0 = affine_map<(d0) -> (d0)> // CHECK-LABEL: func @multiple_results -// CHECK: %[[RESULT1:.*]] = memref.alloc() {{.*}} : memref<4xf32> // CHECK: %[[RESULT0:.*]] = memref.alloc() {{.*}} : memref<4xf32> +// CHECK: %[[RESULT1:.*]] = memref.alloc() {{.*}} : memref<4xf32> // CHECK: linalg.generic // CHECK-SAME: ins(%{{.*}} : memref<4xf32>) // CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref<4xf32>, memref<4xf32>) diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -89,9 +89,9 @@ // CHECK-LABEL: func @tensor.from_elements_1d( // CHECK-SAME: %[[ELEM0:.*]]: index, // CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> { -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<2xindex> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<2xindex> // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]] // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]]] // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] @@ -107,7 +107,7 @@ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex> +// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex> // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]] // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]] // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]] @@ -141,7 +141,7 @@ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2x2xf32> +// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2x2xf32> // CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]] // CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]]