diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -48,6 +48,11 @@
                                 unsigned bitwidthOfIndexType = 64,
                                 unsigned maxRankOfAllocatedMemRef = 1);
 
+/// Creates a pass that promotes heap-based allocations to stack-based ones.
+/// Only buffers with `isSmallAlloc(alloc) == true` are promoted.
+std::unique_ptr<Pass>
+createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
+
 /// Creates a pass that finalizes a partial bufferization by removing remaining
 /// tensor_load and tensor_to_memref operations.
 std::unique_ptr<FunctionPass> createFinalizingBufferizePass();
diff --git a/mlir/lib/Transforms/BufferOptimizations.cpp b/mlir/lib/Transforms/BufferOptimizations.cpp
--- a/mlir/lib/Transforms/BufferOptimizations.cpp
+++ b/mlir/lib/Transforms/BufferOptimizations.cpp
@@ -29,9 +29,9 @@
 /// Check if the size of the allocation is less than the given size. The
 /// transformation is only applied to small buffers since large buffers could
 /// exceed the stack space.
-static bool isSmallAlloc(Value alloc, unsigned maximumSizeInBytes,
-                         unsigned bitwidthOfIndexType,
-                         unsigned maxRankOfAllocatedMemRef) {
+static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes,
+                                unsigned bitwidthOfIndexType,
+                                unsigned maxRankOfAllocatedMemRef) {
   auto type = alloc.getType().dyn_cast<ShapedType>();
   if (!type || !alloc.getDefiningOp<AllocOp>())
     return false;
@@ -299,8 +299,7 @@
       : BufferPlacementTransformationBase(op) {}
 
   /// Promote buffers to stack-based allocations.
-  void promote(unsigned maximumSize, unsigned bitwidthOfIndexType,
-               unsigned maxRankOfAllocatedMemRef) {
+  void promote(function_ref<bool(Value)> isSmallAlloc) {
     for (BufferPlacementAllocs::AllocEntry &entry : allocs) {
       Value alloc = std::get<0>(entry);
       Operation *dealloc = std::get<1>(entry);
@@ -308,9 +307,8 @@
       // The transformation is done if the allocation is limited to a given
       // size. Furthermore, a deallocation must not be defined for this
       // allocation entry and a parent allocation scope must exist.
-      if (!isSmallAlloc(alloc, maximumSize, bitwidthOfIndexType,
-                        maxRankOfAllocatedMemRef) ||
-          dealloc || !hasAllocationScope(alloc, aliases))
+      if (!isSmallAlloc(alloc) || dealloc ||
+          !hasAllocationScope(alloc, aliases))
         continue;
 
       Operation *startOperation = BufferPlacementAllocs::getStartOperation(
@@ -359,9 +357,9 @@
 
 /// The promote buffer to stack pass that tries to convert alloc nodes into
 /// alloca nodes.
-struct PromoteBuffersToStackPass
-    : PromoteBuffersToStackBase<PromoteBuffersToStackPass> {
-
+class PromoteBuffersToStackPass
+    : public PromoteBuffersToStackBase<PromoteBuffersToStackPass> {
+public:
   PromoteBuffersToStackPass(unsigned maxAllocSizeInBytes,
                             unsigned bitwidthOfIndexType,
                             unsigned maxRankOfAllocatedMemRef) {
@@ -370,12 +368,24 @@
     this->maxRankOfAllocatedMemRef = maxRankOfAllocatedMemRef;
   }
 
+  explicit PromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc)
+      : isSmallAlloc(std::move(isSmallAlloc)) {}
+
   void runOnFunction() override {
     // Move all allocation nodes and convert candidates into allocas.
     BufferPlacementPromotion optimizer(getFunction());
-    optimizer.promote(this->maxAllocSizeInBytes, this->bitwidthOfIndexType,
-                      this->maxRankOfAllocatedMemRef);
+    if (isSmallAlloc == nullptr) {
+      isSmallAlloc = [=](Value alloc) {
+        return defaultIsSmallAlloc(alloc, maxAllocSizeInBytes,
+                                   bitwidthOfIndexType,
+                                   maxRankOfAllocatedMemRef);
+      };
+    }
+    optimizer.promote(isSmallAlloc);
   }
+
+private:
+  std::function<bool(Value)> isSmallAlloc;
 };
 
 } // end anonymous namespace
@@ -395,3 +405,8 @@
   return std::make_unique<PromoteBuffersToStackPass>(
       maxAllocSizeInBytes, bitwidthOfIndexType, maxRankOfAllocatedMemRef);
 }
+
+std::unique_ptr<Pass>
+mlir::createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc) {
+  return std::make_unique<PromoteBuffersToStackPass>(std::move(isSmallAlloc));
+}
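
For context, a minimal usage sketch (not part of the patch) of how a downstream pipeline could plug a custom predicate into the new overload. The helper name addCustomPromotion, the 64-element threshold, and the exact header paths are illustrative assumptions and may vary between MLIR revisions.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Hypothetical helper: promote only statically shaped allocations with at
// most 64 elements; everything else keeps its heap allocation.
static void addCustomPromotion(mlir::PassManager &pm) {
  auto isSmallAlloc = [](mlir::Value alloc) {
    auto type = alloc.getType().dyn_cast<mlir::MemRefType>();
    return type && type.hasStaticShape() && type.getNumElements() <= 64;
  };
  // The promotion pass runs on functions, so nest it under FuncOp.
  pm.addNestedPass<mlir::FuncOp>(
      mlir::createPromoteBuffersToStackPass(isSmallAlloc));
}

With this overload the decision is deferred entirely to the callback; the existing three-argument factory keeps the previous size/rank behaviour via defaultIsSmallAlloc.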