diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -113,7 +113,8 @@
 // AllocOp
 //===----------------------------------------------------------------------===//
 
-def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource> {
+def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource,
+    [DeclareOpInterfaceMethods<AllocationOpInterface>]> {
   let summary = "memory allocation operation";
   let description = [{
     The `alloc` operation allocates a region of memory, as specified by its
@@ -409,7 +410,8 @@
 
 def CloneOp : MemRef_Op<"clone", [
     CopyOpInterface,
-    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    DeclareOpInterfaceMethods<AllocationOpInterface>
   ]> {
   let builders = [
     OpBuilder<(ins "Value":$value), [{
@@ -437,6 +439,7 @@
   let extraClassDeclaration = [{
     Value getSource() { return input(); }
     Value getTarget() { return output(); }
+    static Operation *buildCloneOpDealloc(OpBuilder &builder, Value alloc);
   }];
 
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
@@ -16,6 +16,29 @@
 
 include "mlir/Interfaces/SideEffectInterfaceBase.td"
 
+//===----------------------------------------------------------------------===//
+// AllocationOpInterface
+//===----------------------------------------------------------------------===//
+
+def AllocationOpInterface : OpInterface<"AllocationOpInterface"> {
+  let description = [{
+    This interface provides the ability to construct associated deallocation
+    operations that are compatible with the current allocation operation.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Builds a deallocation operation using the provided builder and the
+        current allocation value (which refers to the current Op implementing
+        this interface).
+      }],
+      "::mlir::Operation*", "buildDealloc",
+      (ins "::mlir::OpBuilder&":$opBuilder, "::mlir::Value":$alloc)
+    >
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // MemoryEffects
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/BufferUtils.h b/mlir/include/mlir/Transforms/BufferUtils.h
--- a/mlir/include/mlir/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Transforms/BufferUtils.h
@@ -104,6 +104,9 @@
   /// Constructs a new operation base using the given root operation.
   BufferPlacementTransformationBase(Operation *op);
 
+  /// Returns the underlying list of allocation operations.
+  const BufferPlacementAllocs &getAllocs() const { return allocs; }
+
 protected:
   /// Alias information that can be updated during the insertion of copies.
   BufferViewFlowAnalysis aliases;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -190,6 +190,10 @@
 };
 } // end anonymous namespace.
 
+Operation *AllocOp::buildDealloc(OpBuilder &builder, Value alloc) {
+  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc);
+}
+
 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
@@ -638,6 +642,17 @@
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
+/// Builds a default DeallocOp for all CloneOps.
+Operation *CloneOp::buildCloneOpDealloc(OpBuilder &builder, Value alloc) {
+  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc);
+}
+
+/// Falls back to the static buildCloneOpDealloc function to build a default
+/// DeallocOp for all CloneOps.
+Operation *CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
+  return buildCloneOpDealloc(builder, alloc);
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/BufferDeallocation.cpp b/mlir/lib/Transforms/BufferDeallocation.cpp
--- a/mlir/lib/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Transforms/BufferDeallocation.cpp
@@ -99,6 +99,27 @@
   return success;
 }
 
+/// Checks if all allocation operations either provide an already existing
+/// deallocation operation or implement the AllocationOpInterface.
+static bool validateSupportedAllocations(const BufferPlacementAllocs &allocs) {
+  bool success = true;
+  for (const BufferPlacementAllocs::AllocEntry &entry : allocs) {
+    // Get the defining allocation operation.
+    Operation *definingOp = std::get<0>(entry).getDefiningOp();
+    assert(definingOp &&
+           "DefiningOp must exist in the case of an explicit allocation entry");
+    // If there is no existing deallocation operation and no implementation of
+    // the AllocationOpInterface, we cannot apply the BufferDeallocation pass.
+    if (!std::get<1>(entry) && !isa<AllocationOpInterface>(definingOp)) {
+      definingOp->emitError(
+          "All allocations must either be deallocated explicitly or need to "
+          "implement the AllocationOpInterface.");
+      success = false;
+    }
+  }
+  return success;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -187,7 +208,7 @@
 /// The buffer deallocation transformation which ensures that all allocs in the
 /// program have a corresponding de-allocation. As a side-effect, it might also
 /// introduce clones that in turn leads to additional deallocations.
-class BufferDeallocation : BufferPlacementTransformationBase {
+class BufferDeallocation : public BufferPlacementTransformationBase {
 public:
   BufferDeallocation(Operation *op)
       : BufferPlacementTransformationBase(op), dominators(op),
@@ -493,7 +514,18 @@
           continue;
         // If there is no dealloc node, insert one in the right place.
         OpBuilder builder(nextOp);
-        builder.create<memref::DeallocOp>(alloc.getLoc(), alloc);
+        // First, find out which type of deallocation operation we must build.
+        Operation *definingOp = alloc.getDefiningOp();
+        if (!definingOp || !isa<AllocationOpInterface>(definingOp)) {
+          // Insert a conservative deallocation operation in the case of a
+          // block argument or another alias, as we know that this argument or
+          // alias will contain a copy introduced by a CloneOp.
+          memref::CloneOp::buildCloneOpDealloc(builder, alloc);
+        } else {
+          // Get the allocation op interface and build the associated
+          // deallocation operation.
+          cast<AllocationOpInterface>(definingOp).buildDealloc(builder, alloc);
+        }
       }
     }
   }
@@ -529,12 +561,17 @@
     }
 
     // Check that the control flow structures are supported.
-    if (!validateSupportedControlFlow(func.getRegion())) {
+    if (!validateSupportedControlFlow(func.getRegion()))
       return signalPassFailure();
-    }
 
-    // Place all required temporary clone and dealloc nodes.
+    // Gather all required allocation nodes and prepare the deallocation phase.
     BufferDeallocation deallocation(func);
+
+    // Check for supported AllocationOpInterface implementations.
+    if (!validateSupportedAllocations(deallocation.getAllocs()))
+      return signalPassFailure();
+
+    // Place all required temporary clone and dealloc nodes.
     deallocation.deallocate();
   }
 };