diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -113,7 +113,8 @@
 // AllocOp
 //===----------------------------------------------------------------------===//
 
-def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource> {
+def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource,
+    [DeclareOpInterfaceMethods<AllocationOpInterface>]> {
   let summary = "memory allocation operation";
   let description = [{
     The `alloc` operation allocates a region of memory, as specified by its
@@ -409,7 +410,8 @@
 
 def CloneOp : MemRef_Op<"clone", [
     CopyOpInterface,
-    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    DeclareOpInterfaceMethods<AllocationOpInterface>
   ]> {
   let builders = [
     OpBuilder<(ins "Value":$value), [{
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
@@ -16,6 +16,37 @@
 
 include "mlir/Interfaces/SideEffectInterfaceBase.td"
 
+//===----------------------------------------------------------------------===//
+// AllocationOpInterface
+//===----------------------------------------------------------------------===//
+
+def AllocationOpInterface : OpInterface<"AllocationOpInterface"> {
+  let description = [{
+    This interface provides the ability to construct associated deallocation
+    operations that are compatible with the current allocation operation.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    StaticInterfaceMethod<[{
+        Builds a deallocation operation using the provided builder and the
+        current allocation value (which refers to the current Op implementing
+        this interface).
+      }],
+      "::mlir::Operation*", "buildDealloc",
+      (ins "::mlir::OpBuilder&":$opBuilder, "::mlir::Value":$alloc)
+    >,
+    StaticInterfaceMethod<[{
+        Builds a clone operation using the provided builder and the current
+        allocation value (which refers to the current Op implementing this
+        interface).
+      }],
+      "::mlir::Value", "buildClone",
+      (ins "::mlir::OpBuilder&":$opBuilder, "::mlir::Value":$alloc)
+    >
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // MemoryEffects
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -190,6 +190,14 @@
 };
 } // end anonymous namespace.
 
+Operation *AllocOp::buildDealloc(OpBuilder &builder, Value alloc) {
+  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc);
+}
+
+Value AllocOp::buildClone(OpBuilder &builder, Value alloc) {
+  return builder.create<memref::CloneOp>(alloc.getLoc(), alloc);
+}
+
 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
@@ -638,6 +646,14 @@
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
+Operation *CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
+  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc);
+}
+
+Value CloneOp::buildClone(OpBuilder &builder, Value alloc) {
+  return builder.create<memref::CloneOp>(alloc.getLoc(), alloc);
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/BufferDeallocation.cpp b/mlir/lib/Transforms/BufferDeallocation.cpp
--- a/mlir/lib/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Transforms/BufferDeallocation.cpp
@@ -187,12 +187,55 @@
 /// The buffer deallocation transformation which ensures that all allocs in the
 /// program have a corresponding de-allocation. As a side-effect, it might also
 /// introduce clones that in turn leads to additional deallocations.
-class BufferDeallocation : BufferPlacementTransformationBase {
+class BufferDeallocation : public BufferPlacementTransformationBase {
 public:
+  using AliasAllocationMapT = llvm::DenseMap<Value, AllocationOpInterface>;
+
   BufferDeallocation(Operation *op)
       : BufferPlacementTransformationBase(op), dominators(op),
         postDominators(op) {}
 
+  /// Checks if all allocation operations either provide an already existing
+  /// deallocation operation or implement the AllocationOpInterface. In
+  /// addition, this method initializes the internal alias to
+  /// AllocationOpInterface mapping in order to get compatible
+  /// AllocationOpInterface implementations for aliases.
+  LogicalResult prepare() {
+    for (const BufferPlacementAllocs::AllocEntry &entry : allocs) {
+      // Get the defining allocation operation.
+      Value alloc = std::get<0>(entry);
+      Operation *definingOp = alloc.getDefiningOp();
+      assert(
+          definingOp &&
+          "DefiningOp must exist in the case of an explicit allocation entry");
+      // If there is no existing deallocation operation and no implementation of
+      // the AllocationOpInterface, we cannot apply the BufferDeallocation pass.
+      auto allocationInterface = dyn_cast<AllocationOpInterface>(definingOp);
+      if (!std::get<1>(entry) && !allocationInterface) {
+        return definingOp->emitError(
+            "Allocation is not deallocated explicitly nor does the operation "
+            "implement the AllocationOpInterface.");
+      }
+
+      // Explicitly deallocated allocations need not implement the interface;
+      // their aliases fall back to the default memref::CloneOp.
+      if (!allocationInterface)
+        continue;
+
+      // Register the current allocation interface implementation.
+      aliasToAllocations[alloc] = allocationInterface;
+
+      // Get the alias information for the current allocation node.
+      llvm::for_each(aliases.resolve(alloc), [&](Value alias) {
+        // TODO: check for incompatible implementations of the
+        // AllocationOpInterface. This could be realized by promoting the
+        // AllocationOpInterface to a DialectInterface.
+        aliasToAllocations[alias] = allocationInterface;
+      });
+    }
+    return success();
+  }
+
   /// Performs the actual placement/creation of all temporary clone and dealloc
   /// nodes.
   void deallocate() {
@@ -423,8 +466,7 @@
     // Create a new clone operation that copies the contents of the old
     // buffer to the new one.
     OpBuilder builder(terminator);
-    auto cloneOp =
-        builder.create<memref::CloneOp>(terminator->getLoc(), sourceValue);
+    Value cloneOp = buildClone(builder, sourceValue);
 
     // Remember the clone of original source value.
     clonedValues.insert(cloneOp);
@@ -493,11 +535,66 @@
           continue;
         // If there is no dealloc node, insert one in the right place.
         OpBuilder builder(nextOp);
-        builder.create<memref::DeallocOp>(alloc.getLoc(), alloc);
+        // Get the allocation op interface and build the associated
+        // deallocation operation.
+        buildDealloc(builder, alloc);
       }
     }
   }
 
+  /// Applies the interfaceFn method to the interface implementation associated
+  /// with the given value. If there is no registered interface implementation,
+  /// this method calls the defaultFn to build a default dealloc/clone
+  /// operation.
+  template <typename InterfaceFnT, typename DefaultFnT>
+  auto applyAllocationInterface(Value value, const InterfaceFnT &interfaceFn,
+                                const DefaultFnT &defaultFn) const {
+    auto it = aliasToAllocations.find(value);
+    if (it != aliasToAllocations.end())
+      return interfaceFn(it->second);
+    // Fall back to the "default" operation for unknown allocation sources.
+    // This can occur in the case of function parameters.
+    return defaultFn();
+  }
+
+  /// Builds a deallocation operation compatible with the given allocation
+  /// value. If there is no registered AllocationOpInterface implementation for
+  /// the given value (e.g. in the case of a function parameter), this method
+  /// builds a memref::DeallocOp.
+  Operation *buildDealloc(OpBuilder &builder, Value alloc) const {
+    return applyAllocationInterface(
+        alloc,
+        [&](AllocationOpInterface interface) {
+          // Call the allocation op interface to build a supported and
+          // compatible deallocation operation.
+          return interface.buildDealloc(builder, alloc);
+        },
+        [&]() {
+          // Build a "default" DeallocOp for unknown allocation sources.
+          return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
+              .getOperation();
+        });
+  }
+
+  /// Builds a clone operation compatible with the given allocation value. If
+  /// there is no registered AllocationOpInterface implementation for the given
+  /// value (e.g. in the case of a function parameter), this method builds a
+  /// memref::CloneOp.
+  Value buildClone(OpBuilder &builder, Value alloc) const {
+    return applyAllocationInterface(
+        alloc,
+        [&](AllocationOpInterface interface) {
+          // Call the allocation op interface to build a supported and
+          // compatible clone operation.
+          return interface.buildClone(builder, alloc);
+        },
+        [&]() {
+          // Build a "default" CloneOp for unknown allocation sources.
+          return builder.create<memref::CloneOp>(alloc.getLoc(), alloc)
+              .getResult();
+        });
+  }
+
   /// The dominator info to find the appropriate start operation to move the
   /// allocs.
   DominanceInfo dominators;
@@ -508,6 +605,9 @@
 
   /// Stores already cloned buffers to avoid additional clones of clones.
   ValueSetT clonedValues;
+
+  /// Maps aliases to their source allocation interfaces (inverse mapping).
+  AliasAllocationMapT aliasToAllocations;
 };
 
 //===----------------------------------------------------------------------===//
@@ -529,12 +629,18 @@
     }
 
     // Check that the control flow structures are supported.
-    if (!validateSupportedControlFlow(func.getRegion())) {
+    if (!validateSupportedControlFlow(func.getRegion()))
       return signalPassFailure();
-    }
 
-    // Place all required temporary clone and dealloc nodes.
+    // Gather all required allocation nodes and prepare the deallocation phase.
     BufferDeallocation deallocation(func);
+
+    // Check for supported AllocationOpInterface implementations and prepare the
+    // internal deallocation pass.
+    if (failed(deallocation.prepare()))
+      return signalPassFailure();
+
+    // Place all required temporary clone and dealloc nodes.
     deallocation.deallocate();
   }
 };