diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td --- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td @@ -19,6 +19,35 @@ def Transform_EmptyOp : Transform_ConcreteOpType<"tensor.empty">; def Transform_AllocTensorOp : Transform_ConcreteOpType<"bufferization.alloc_tensor">; +//===----------------------------------------------------------------------===// +// BufferLoopHoistingOp +//===----------------------------------------------------------------------===// + +def BufferLoopHoistingOp + : Op<Transform_Dialect, "bufferization.buffer_loop_hoisting", + [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + TransformEachOpTrait, TransformOpInterface]> { + let description = [{ + Hoist buffer allocations ("memref.alloc" and "memref.alloca") from loops + within the targeted op. This transform assumes that there are no buffer + deallocation ops in the IR. + + This transform reads the `target` handle and modifies the payload.
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + let assemblyFormat = "$target attr-dict `:` type($target)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + //===----------------------------------------------------------------------===// // OneShotBufferizeOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h @@ -41,6 +41,10 @@ AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc); +/// Within the given operation, hoist buffers from loops where possible. See +/// "BufferLoopHoistingPass" for more information. +void hoistBuffersFromLoops(Operation *op); + /// Try to eliminate tensor::EmptyOps inside `op` that are anchored on an /// InsertSliceOp, i.e., if it is eventually inserted into another tensor /// (and some other conditions are met).
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -21,6 +21,23 @@ using namespace mlir::bufferization; using namespace mlir::transform; +//===----------------------------------------------------------------------===// +// BufferLoopHoistingOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::BufferLoopHoistingOp::applyToOne( + TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, TransformState &state) { + bufferization::hoistBuffersFromLoops(target); + return DiagnosedSilenceableFailure::success(); +} + +void transform::BufferLoopHoistingOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTarget(), effects); + modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // OneShotBufferizeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" +#include "mlir/Dialect/Bufferization/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Operation.h" @@ -389,9 +390,7 @@ void runOnOperation() override { // Hoist all allocations out of loops.
- BufferAllocationHoisting<BufferAllocationLoopHoistingState> optimizer( - getOperation()); - optimizer.hoist(); + hoistBuffersFromLoops(getOperation()); } }; @@ -432,6 +431,11 @@ } // namespace +void mlir::bufferization::hoistBuffersFromLoops(Operation *op) { + BufferAllocationHoisting<BufferAllocationLoopHoistingState> optimizer(op); + optimizer.hoist(); +} + std::unique_ptr<Pass> mlir::bufferization::createBufferHoistingPass() { return std::make_unique<BufferHoistingPass>(); } diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir @@ -151,3 +151,23 @@ : tensor<5xf32> into tensor<10xf32> return %2 : tensor<10xf32> } + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.bufferization.buffer_loop_hoisting %0 : !transform.any_op +} + +// CHECK-LABEL: func @buffer_loop_hoisting( +// CHECK: memref.alloca +// CHECK: scf.for +// CHECK: memref.store +func.func @buffer_loop_hoisting(%lb: index, %ub: index, %step: index, %f: f32, %pos: index) { + scf.for %iv = %lb to %ub step %step { + %0 = memref.alloca() : memref<5xf32> + memref.store %f, %0[%pos] : memref<5xf32> + } + return +}