diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.h
@@ -15,6 +15,16 @@
 
 #include "mlir/IR/Builders.h"
 
+namespace mlir {
+/// Bit flags describing the kinds of hoisting an allocation operation
+/// supports; flags may be OR-ed together via their underlying integer value.
+enum class HoistingKind : std::uint8_t {
+  None = 0,      // No hoisting kind selected
+  Loop = 1 << 0, // Indicates loop hoisting kind
+  Block = 1 << 1 // Indicates dominated block hoisting kind
+};
+} // namespace mlir
+
 #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h.inc"
 
 #endif // MLIR_DIALECT_BUFFERIZATION_IR_ALLOCATIONOPINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.td
@@ -51,6 +51,25 @@
       "::std::optional<::mlir::Value>", "buildClone",
       (ins "::mlir::OpBuilder&":$builder, "::mlir::Value":$alloc), [{}],
       /*defaultImplementation=*/[{ return std::nullopt; }]
+    >,
+    StaticInterfaceMethod<[{
+        Returns the kind of hoisting supported for the buffer allocated by this
+        operation.
+      }],
+      "::mlir::HoistingKind", "getHoistingKind",
+      (ins), [{}],
+      /*defaultImplementation=*/[{ return HoistingKind::None; }]
+    >,
+    StaticInterfaceMethod<[{
+        Builds a stack allocation operation using the provided builder and the
+        current allocation value (which refers to the current Op implementing this
+        interface). The allocation value is a result of the current
+        operation implementing this interface. If there is no compatible
+        stack allocation operation, this method can return ::std::nullopt.
+      }],
+      "::std::optional<::mlir::Operation*>", "buildPromotedAlloc",
+      (ins "::mlir::OpBuilder&":$builder, "::mlir::Value":$alloc), [{}],
+      /*defaultImplementation=*/[{ return std::nullopt; }]
     >
   ];
 }
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -175,4 +176,5 @@
 void mlir::bufferization::registerTransformDialectExtension(
     DialectRegistry &registry) {
   registry.addExtensions();
+  bufferization::registerAllocationOpInterfaceExternalModels(registry);
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -637,6 +637,23 @@
     return builder.create(alloc.getLoc(), alloc)
         .getResult();
   }
+  static ::mlir::HoistingKind getHoistingKind() {
+    return static_cast(static_cast(HoistingKind::Loop) |
+                       static_cast(HoistingKind::Block));
+  }
+  static ::std::optional<::mlir::Operation *>
+  buildPromotedAlloc(OpBuilder &builder, Value alloc) {
+    Operation *definingOp = alloc.getDefiningOp();
+    return builder.create(
+        definingOp->getLoc(), cast(definingOp->getResultTypes()[0]),
+        definingOp->getOperands(), definingOp->getAttrs());
+  }
+};
+
+struct DefaultAutomaticAllocationHoistingInterface
+    : public bufferization::AllocationOpInterface::ExternalModel<
+          DefaultAutomaticAllocationHoistingInterface, memref::AllocaOp> {
+  static ::mlir::HoistingKind getHoistingKind() { return HoistingKind::Loop; }
 };
 
 struct DefaultReallocationInterface
@@ -713,6 +730,8 @@
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
     memref::AllocOp::attachInterface(*ctx);
+    memref::AllocaOp::attachInterface<
+        DefaultAutomaticAllocationHoistingInterface>(*ctx);
     memref::ReallocOp::attachInterface(*ctx);
   });
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 
+#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -39,6 +40,22 @@
   return isa(op);
 }
 
+/// Returns true if the given operation implements the AllocationOpInterface
+/// and its hoisting kind has the HoistingKind::Block flag set.
+static bool allowAllocDominateBlockHoisting(Operation *op) {
+  auto allocOp = dyn_cast(op);
+  return allocOp && (static_cast(allocOp.getHoistingKind()) &
+                     static_cast(HoistingKind::Block));
+}
+
+/// Returns true if the given operation implements the AllocationOpInterface
+/// and its hoisting kind has the HoistingKind::Loop flag set.
+static bool allowAllocLoopHoisting(Operation *op) {
+  auto allocOp = dyn_cast(op);
+  return allocOp && (static_cast(allocOp.getHoistingKind()) &
+                     static_cast(HoistingKind::Loop));
+}
+
 /// Check if the size of the allocation is less than the given size. The
 /// transformation is only applied to small buffers since large buffers could
 /// exceed the stack space.
@@ -278,7 +295,7 @@
 
   /// Returns true if the given operation should be considered for hoisting.
   static bool shouldHoistOpType(Operation *op) {
-    return llvm::isa(op);
+    return allowAllocDominateBlockHoisting(op);
   }
 
   /// Sets the current placement block to the given block.
@@ -315,7 +332,7 @@
 
   /// Returns true if the given operation should be considered for hoisting.
   static bool shouldHoistOpType(Operation *op) {
-    return llvm::isa(op);
+    return allowAllocLoopHoisting(op);
   }
 
   /// Does not change the internal placement block, as we want to move
@@ -355,13 +372,16 @@
       // `AutomaticAllocationScope` determined during the initialization phase.
       OpBuilder builder(startOperation);
       Operation *allocOp = alloc.getDefiningOp();
-      Operation *alloca = builder.create(
-          alloc.getLoc(), cast(alloc.getType()),
-          allocOp->getOperands(), allocOp->getAttrs());
-
-      // Replace the original alloc by a newly created alloca.
-      allocOp->replaceAllUsesWith(alloca);
-      allocOp->erase();
+      if (auto allocInterface = dyn_cast(allocOp)) {
+        std::optional<Operation *> alloca =
+            allocInterface.buildPromotedAlloc(builder, alloc);
+        // The default buildPromotedAlloc returns nullopt: keep the allocation.
+        if (!alloca || !*alloca)
+          continue;
+        // Replace the original alloc by a newly created alloca.
+        allocOp->replaceAllUsesWith(alloca.value());
+        allocOp->erase();
+      }
     }
   }
 };