diff --git a/mlir/include/mlir/Transforms/BufferUtils.h b/mlir/include/mlir/Transforms/BufferUtils.h --- a/mlir/include/mlir/Transforms/BufferUtils.h +++ b/mlir/include/mlir/Transforms/BufferUtils.h @@ -25,6 +25,27 @@ namespace mlir { +/// A configuration structure for the BufferDeallocation pass. It allows +/// defining custom operations when allocating/cloning/deallocating buffers. +struct BufferDeallocationConfig { + virtual ~BufferDeallocationConfig() = default; + + /// Creates a default BufferDeallocationConfig that uses std::alloc and + /// linalg::copy to clone buffers and std::dealloc to free buffers. + static std::shared_ptr createDefault(); + + /// Clones the given source value into a (potentially) new buffer. This allows + /// defining custom create/copy operations for all allocations managed by + /// the BufferDeallocation pass. + virtual Value clone(OpBuilder &builder, Location location, + Value source) const = 0; + + /// Frees the given source buffer. This allows defining custom free + /// operations for all allocations managed by the BufferDeallocation pass. + virtual void free(OpBuilder &builder, Location location, + Value source) const = 0; +}; + /// A simple analysis that detects allocation operations. class BufferPlacementAllocs { public: diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -23,15 +23,23 @@ namespace mlir { class AffineForOp; +struct BufferDeallocationConfig; //===----------------------------------------------------------------------===// // Passes //===----------------------------------------------------------------------===// /// Creates an instance of the BufferDeallocation pass to free all allocated -/// buffers. +/// buffers. This overload uses std::alloc, std::dealloc and linalg::copy to +/// create, clone and free buffers. 
std::unique_ptr createBufferDeallocationPass(); +/// Creates an instance of the BufferDeallocation pass to free all allocated +/// buffers. This overload uses the given configuration to allocate/clone and +/// free buffers. +std::unique_ptr +createBufferDeallocationPass(std::shared_ptr config); + /// Creates a pass that moves allocations upwards to reduce the number of /// required copies that are inserted during the BufferDeallocation pass. std::unique_ptr createBufferHoistingPass(); diff --git a/mlir/lib/Transforms/BufferDeallocation.cpp b/mlir/lib/Transforms/BufferDeallocation.cpp --- a/mlir/lib/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Transforms/BufferDeallocation.cpp @@ -163,8 +163,8 @@ /// introduce copies that in turn leads to additional allocs and de-allocations. class BufferDeallocation : BufferPlacementTransformationBase { public: - BufferDeallocation(Operation *op) - : BufferPlacementTransformationBase(op), dominators(op), + BufferDeallocation(Operation *op, const BufferDeallocationConfig &config) + : BufferPlacementTransformationBase(op), config(config), dominators(op), postDominators(op) {} /// Performs the actual placement/creation of all temporary alloc, copy and @@ -388,32 +388,13 @@ // algorithm. if (copiedValues.contains(sourceValue)) return sourceValue; - // Create a new alloc at the current location of the terminator. - auto memRefType = sourceValue.getType().cast(); + // Create a new clone at the current location of the terminator. OpBuilder builder(terminator); - - // Extract information about dynamically shaped types by - // extracting their dynamic dimensions. - SmallVector dynamicOperands; - for (auto shapeElement : llvm::enumerate(memRefType.getShape())) { - if (!ShapedType::isDynamic(shapeElement.value())) - continue; - dynamicOperands.push_back(builder.create( - terminator->getLoc(), sourceValue, shapeElement.index())); - } - - // TODO: provide a generic interface to create dialect-specific - // Alloc and CopyOp nodes. 
- auto alloc = builder.create(terminator->getLoc(), memRefType, - dynamicOperands); - - // Create a new copy operation that copies to contents of the old - // allocation to the new one. - builder.create(terminator->getLoc(), sourceValue, alloc); + Value clone = config.clone(builder, terminator->getLoc(), sourceValue); // Remember the copy of original source value. - copiedValues.insert(alloc); - return alloc; + copiedValues.insert(clone); + return clone; } /// Finds correct dealloc positions according to the algorithm described at @@ -478,11 +459,14 @@ continue; // If there is no dealloc node, insert one in the right place. OpBuilder builder(nextOp); - builder.create(alloc.getLoc(), alloc); + config.free(builder, alloc.getLoc(), alloc); } } } + /// Configuration used to clone/free buffers; held by reference, not owned. + const BufferDeallocationConfig &config; + /// The dominator info to find the appropriate start operation to move the /// allocs. DominanceInfo dominators; @@ -503,6 +487,10 @@ /// into the right positions. Furthermore, it inserts additional allocs and /// copies if necessary. It uses the algorithm described at the top of the file. struct BufferDeallocationPass : BufferDeallocationBase { + std::shared_ptr config; + + BufferDeallocationPass(std::shared_ptr config) + : config(config) {} void runOnFunction() override { // Ensure that there are supported loops only. @@ -514,17 +502,80 @@ } // Place all required temporary alloc, copy and dealloc nodes. - BufferDeallocation deallocation(getFunction()); + BufferDeallocation deallocation(getFunction(), *config.get()); deallocation.deallocate(); } }; +//===----------------------------------------------------------------------===// +// DefaultBufferDeallocationConfig +//===----------------------------------------------------------------------===// + +/// A default implementation of the abstract BufferDeallocationConfig structure +/// that uses std::alloc, linalg::copy and std::dealloc operations. 
+struct DefaultBufferDeallocationConfig : public BufferDeallocationConfig { + /// Clones the given source value into a new buffer using std::alloc and + /// linalg::copy. + Value clone(OpBuilder &builder, Location location, + Value source) const override { + // Extract information about dynamically shaped types by extracting their + // dynamic dimensions. + auto memRefType = source.getType().cast(); + SmallVector dynamicOperands; + for (auto shapeElement : llvm::enumerate(memRefType.getShape())) { + if (!ShapedType::isDynamic(shapeElement.value())) + continue; + dynamicOperands.push_back( + builder.create(location, source, shapeElement.index())); + } + + // TODO: Once we have a single "std::clone"-like operation that can express + // an allocation and a copy in one single operation, the following piece + // of code needs to be adapted accordingly. This enables support for + // reference-counted allocations, for instance. + + // Create a new std::alloc in this default implementation. + auto alloc = builder.create(location, memRefType, dynamicOperands); + + // Create a new linalg copy operation that copies the contents of the old + // allocation to the new one. + builder.create(location, source, alloc); + + // Return the temporary allocation. + return alloc; + } + + /// Frees the given source buffer by emitting an std::dealloc. + void free(OpBuilder &builder, Location location, + Value source) const override { + // Emit a default std::dealloc to free the source buffer. + builder.create(location, source); + } +}; + } // end anonymous namespace +//===----------------------------------------------------------------------===// +// BufferDeallocationConfig +//===----------------------------------------------------------------------===// + +/// Creates and returns a new instance of the DefaultBufferDeallocationConfig +/// structure. 
+std::shared_ptr +mlir::BufferDeallocationConfig::createDefault() { + return std::make_shared(); +} + //===----------------------------------------------------------------------===// // BufferDeallocationPass construction //===----------------------------------------------------------------------===// std::unique_ptr mlir::createBufferDeallocationPass() { - return std::make_unique(); + return createBufferDeallocationPass( + BufferDeallocationConfig::createDefault()); +} + +std::unique_ptr mlir::createBufferDeallocationPass( + std::shared_ptr config) { + return std::make_unique(config); }