diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -119,7 +119,10 @@
 // AllocOp
 //===----------------------------------------------------------------------===//

-def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource> {
+def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource, [
+    DeclareOpInterfaceMethods<AllocationOpInterface,
+                              ["buildDealloc", "buildClone"]>]
+  > {
   let summary = "memory allocation operation";
   let description = [{
     The `alloc` operation allocates a region of memory, as specified by its
@@ -415,7 +418,9 @@

 def CloneOp : MemRef_Op<"clone", [
     CopyOpInterface,
-    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    DeclareOpInterfaceMethods<AllocationOpInterface, ["buildDealloc", "buildClone"]>
   ]> {
   let builders = [
     OpBuilder<(ins "Value":$value), [{
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
@@ -16,6 +16,45 @@

 include "mlir/Interfaces/SideEffectInterfaceBase.td"

+//===----------------------------------------------------------------------===//
+// AllocationOpInterface
+//===----------------------------------------------------------------------===//
+
+def AllocationOpInterface : OpInterface<"AllocationOpInterface"> {
+  let description = [{
+    This interface provides general allocation-related methods that are
+    designed for allocation operations. For example, it offers the ability to
+    construct associated deallocation and clone operations that are compatible
+    with the current allocation operation.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    StaticInterfaceMethod<[{
+        Builds a deallocation operation using the provided builder and the
+        current allocation value (which refers to the current Op implementing
+        this interface). The allocation value is a result of the current
+        operation implementing this interface. If there is no compatible
+        deallocation operation, this method can return ::llvm::None.
+      }],
+      "::mlir::Optional<::mlir::Operation*>", "buildDealloc",
+      (ins "::mlir::OpBuilder&":$opBuilder, "::mlir::Value":$alloc), [{}],
+      /*defaultImplementation=*/[{ return llvm::None; }]
+    >,
+    StaticInterfaceMethod<[{
+        Builds a clone operation using the provided builder and the current
+        allocation value (which refers to the current Op implementing this
+        interface). The allocation value is a result of the current operation
+        implementing this interface. If there is no compatible clone operation,
+        this method can return ::llvm::None.
+      }],
+      "::mlir::Optional<::mlir::Value>", "buildClone",
+      (ins "::mlir::OpBuilder&":$opBuilder, "::mlir::Value":$alloc), [{}],
+      /*defaultImplementation=*/[{ return llvm::None; }]
+    >
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // MemoryEffects
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -190,6 +190,15 @@
 };
 } // end anonymous namespace.

+Optional<Operation *> AllocOp::buildDealloc(OpBuilder &builder, Value alloc) {
+  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
+      .getOperation();
+}
+
+Optional<Value> AllocOp::buildClone(OpBuilder &builder, Value alloc) {
+  return builder.create<memref::CloneOp>(alloc.getLoc(), alloc).getResult();
+}
+
 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context);
@@ -638,6 +647,15 @@
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }

+Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
+  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
+      .getOperation();
+}
+
+Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
+  return builder.create<memref::CloneOp>(alloc.getLoc(), alloc).getResult();
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/BufferDeallocation.cpp b/mlir/lib/Transforms/BufferDeallocation.cpp
--- a/mlir/lib/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Transforms/BufferDeallocation.cpp
@@ -64,14 +64,18 @@
 using namespace mlir;

 /// Walks over all immediate return-like terminators in the given region.
-static void walkReturnOperations(Region *region,
-                                 std::function<void(Operation *)> func) {
+static LogicalResult
+walkReturnOperations(Region *region,
+                     std::function<LogicalResult(Operation *)> func) {
   for (Block &block : *region) {
     Operation *terminator = block.getTerminator();
     // Skip non region-return-like terminators.
-    if (isRegionReturnLike(terminator))
-      func(terminator);
+    if (isRegionReturnLike(terminator)) {
+      if (failed(func(terminator)))
+        return failure();
+    }
   }
+  return success();
 }

 /// Checks if all operations in a given region that have at least one attached
@@ -187,24 +191,60 @@
 /// The buffer deallocation transformation which ensures that all allocs in the
 /// program have a corresponding de-allocation. As a side-effect, it might also
 /// introduce clones that in turn leads to additional deallocations.
-class BufferDeallocation : BufferPlacementTransformationBase {
+class BufferDeallocation : public BufferPlacementTransformationBase {
 public:
+  using AliasAllocationMapT = llvm::DenseMap<Value, AllocationOpInterface>;
+
   BufferDeallocation(Operation *op)
       : BufferPlacementTransformationBase(op), dominators(op),
         postDominators(op) {}

+  /// Checks if all allocation operations either provide an already existing
+  /// deallocation operation or implement the AllocationOpInterface. In
+  /// addition, this method initializes the internal alias to
+  /// AllocationOpInterface mapping in order to get compatible
+  /// AllocationOpInterface implementations for aliases.
+  LogicalResult prepare() {
+    for (const BufferPlacementAllocs::AllocEntry &entry : allocs) {
+      // Get the defining allocation operation.
+      Value alloc = std::get<0>(entry);
+      auto allocationInterface = alloc.getDefiningOp<AllocationOpInterface>();
+      // If there is no existing deallocation operation and no implementation of
+      // the AllocationOpInterface, we cannot apply the BufferDeallocation pass.
+      if (!std::get<1>(entry) && !allocationInterface) {
+        return alloc.getDefiningOp()->emitError(
+            "Allocation is not deallocated explicitly nor does the operation "
+            "implement the AllocationOpInterface.");
+      }
+
+      // Register the current allocation interface implementation.
+      aliasToAllocations[alloc] = allocationInterface;
+
+      // Get the alias information for the current allocation node.
+      llvm::for_each(aliases.resolve(alloc), [&](Value alias) {
+        // TODO: check for incompatible implementations of the
+        // AllocationOpInterface. This could be realized by promoting the
+        // AllocationOpInterface to a DialectInterface.
+        aliasToAllocations[alias] = allocationInterface;
+      });
+    }
+    return success();
+  }
+
   /// Performs the actual placement/creation of all temporary clone and dealloc
   /// nodes.
-  void deallocate() {
+  LogicalResult deallocate() {
     // Add additional clones that are required.
-    introduceClones();
+    if (failed(introduceClones()))
+      return failure();
+
     // Place deallocations for all allocation entries.
-    placeDeallocs();
+    return placeDeallocs();
   }

 private:
   /// Introduces required clone operations to avoid memory leaks.
-  void introduceClones() {
+  LogicalResult introduceClones() {
     // Initialize the set of values that require a dedicated memory free
     // operation since their operands cannot be safely deallocated in a post
     // dominator.
@@ -256,21 +296,22 @@

     // Add new allocs and additional clone operations.
     for (Value value : valuesToFree) {
-      if (auto blockArg = value.dyn_cast<BlockArgument>())
-        introduceBlockArgCopy(blockArg);
-      else
-        introduceValueCopyForRegionResult(value);
+      if (failed(value.isa<BlockArgument>()
+                     ? introduceBlockArgCopy(value.cast<BlockArgument>())
+                     : introduceValueCopyForRegionResult(value)))
+        return failure();

       // Register the value to require a final dealloc. Note that we do not have
       // to assign a block here since we do not want to move the allocation node
       // to another location.
       allocs.registerAlloc(std::make_tuple(value, nullptr));
     }
+    return success();
   }

   /// Introduces temporary clones in all predecessors and copies the source
   /// values into the newly allocated buffers.
-  void introduceBlockArgCopy(BlockArgument blockArg) {
+  LogicalResult introduceBlockArgCopy(BlockArgument blockArg) {
     // Allocate a buffer for the current block argument in the block of
     // the associated value (which will be a predecessor block by
     // definition).
@@ -284,18 +325,21 @@
       Value sourceValue =
           branchInterface.getSuccessorOperands(it.getSuccessorIndex())
               .getValue()[blockArg.getArgNumber()];
-      // Create a new clone at the current location of the terminator.
-      Value clone = introduceCloneBuffers(sourceValue, terminator);
       // Wire new clone and successor operand.
       auto mutableOperands =
           branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex());
-      if (!mutableOperands.hasValue())
+      if (!mutableOperands) {
         terminator->emitError() << "terminators with immutable successor "
                                    "operands are not supported";
-      else
-        mutableOperands.getValue()
-            .slice(blockArg.getArgNumber(), 1)
-            .assign(clone);
+        return failure();
+      }
+      // Create a new clone at the current location of the terminator.
+      auto clone = introduceCloneBuffers(sourceValue, terminator);
+      if (failed(clone))
+        return failure();
+      mutableOperands.getValue()
+          .slice(blockArg.getArgNumber(), 1)
+          .assign(*clone);
     }

     // Check whether the block argument has implicitly defined predecessors via
@@ -307,14 +351,15 @@
     RegionBranchOpInterface regionInterface;
     if (!argRegion || &argRegion->front() != block ||
         !(regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp)))
-      return;
+      return success();

-    introduceClonesForRegionSuccessors(
-        regionInterface, argRegion->getParentOp()->getRegions(), blockArg,
-        [&](RegionSuccessor &successorRegion) {
-          // Find a predecessor of our argRegion.
-          return successorRegion.getSuccessor() == argRegion;
-        });
+    if (failed(introduceClonesForRegionSuccessors(
+            regionInterface, argRegion->getParentOp()->getRegions(), blockArg,
+            [&](RegionSuccessor &successorRegion) {
+              // Find a predecessor of our argRegion.
+              return successorRegion.getSuccessor() == argRegion;
+            })))
+      return failure();

     // Check whether the block argument belongs to an entry region of the
     // parent operation. In this case, we have to introduce an additional clone
@@ -326,24 +371,27 @@
           return successorRegion.getSuccessor() == argRegion;
         });
     if (it == successorRegions.end())
-      return;
+      return success();

     // Determine the actual operand to introduce a clone for and rewire the
     // operand to point to the clone instead.
     Value operand =
         regionInterface.getSuccessorEntryOperands(argRegion->getRegionNumber())
             [llvm::find(it->getSuccessorInputs(), blockArg).getIndex()];
-    Value clone = introduceCloneBuffers(operand, parentOp);
+    auto clone = introduceCloneBuffers(operand, parentOp);
+    if (failed(clone))
+      return failure();

     auto op = llvm::find(parentOp->getOperands(), operand);
     assert(op != parentOp->getOperands().end() &&
            "parentOp does not contain operand");
-    parentOp->setOperand(op.getIndex(), clone);
+    parentOp->setOperand(op.getIndex(), *clone);
+    return success();
   }

   /// Introduces temporary clones in front of all associated nested-region
   /// terminators and copies the source values into the newly allocated buffers.
-  void introduceValueCopyForRegionResult(Value value) {
+  LogicalResult introduceValueCopyForRegionResult(Value value) {
     // Get the actual result index in the scope of the parent terminator.
     Operation *operation = value.getDefiningOp();
     auto regionInterface = cast<RegionBranchOpInterface>(operation);
@@ -358,15 +406,15 @@
     // been considered critical. Therefore, the algorithm assumes that a clone
     // of a previously allocated buffer is returned by the operation (like in
     // the case of a block argument).
-    introduceClonesForRegionSuccessors(regionInterface, operation->getRegions(),
-                                       value, regionPredicate);
+    return introduceClonesForRegionSuccessors(
+        regionInterface, operation->getRegions(), value, regionPredicate);
   }

   /// Introduces buffer clones for all terminators in the given regions. The
   /// regionPredicate is applied to every successor region in order to restrict
   /// the clones to specific regions.
   template <typename TPredicate>
-  void introduceClonesForRegionSuccessors(
+  LogicalResult introduceClonesForRegionSuccessors(
       RegionBranchOpInterface regionInterface, MutableArrayRef<Region> regions,
       Value argValue, const TPredicate &regionPredicate) {
     for (Region &region : regions) {
@@ -389,27 +437,33 @@
       // Iterate over all immediate terminator operations to introduce
       // new buffer allocations. Thereby, the appropriate terminator operand
       // will be adjusted to point to the newly allocated buffer instead.
-      walkReturnOperations(&region, [&](Operation *terminator) {
-        // Get the actual mutable operands for this terminator op.
-        auto terminatorOperands = *getMutableRegionBranchSuccessorOperands(
-            terminator, region.getRegionNumber());
-        // Extract the source value from the current terminator.
-        // This conversion needs to exist on a separate line due to a bug in
-        // GCC conversion analysis.
-        OperandRange immutableTerminatorOperands = terminatorOperands;
-        Value sourceValue = immutableTerminatorOperands[operandIndex];
-        // Create a new clone at the current location of the terminator.
-        Value clone = introduceCloneBuffers(sourceValue, terminator);
-        // Wire clone and terminator operand.
-        terminatorOperands.slice(operandIndex, 1).assign(clone);
-      });
+      if (failed(walkReturnOperations(&region, [&](Operation *terminator) {
+            // Get the actual mutable operands for this terminator op.
+            auto terminatorOperands = *getMutableRegionBranchSuccessorOperands(
+                terminator, region.getRegionNumber());
+            // Extract the source value from the current terminator.
+            // This conversion needs to exist on a separate line due to a bug in
+            // GCC conversion analysis.
+            OperandRange immutableTerminatorOperands = terminatorOperands;
+            Value sourceValue = immutableTerminatorOperands[operandIndex];
+            // Create a new clone at the current location of the terminator.
+            auto clone = introduceCloneBuffers(sourceValue, terminator);
+            if (failed(clone))
+              return failure();
+            // Wire clone and terminator operand.
+            terminatorOperands.slice(operandIndex, 1).assign(*clone);
+            return success();
+          })))
+        return failure();
     }
+    return success();
   }

   /// Creates a new memory allocation for the given source value and clones
   /// its content into the newly allocated buffer. The terminator operation is
   /// used to insert the clone operation at the right place.
-  Value introduceCloneBuffers(Value sourceValue, Operation *terminator) {
+  FailureOr<Value> introduceCloneBuffers(Value sourceValue,
+                                         Operation *terminator) {
     // Avoid multiple clones of the same source value. This can happen in the
     // presence of loops when a branch acts as a backedge while also having
     // another successor that returns to its parent operation. Note: that
@@ -422,19 +476,18 @@
       return sourceValue;
     // Create a new clone operation that copies the contents of the old
     // buffer to the new one.
-    OpBuilder builder(terminator);
-    auto cloneOp =
-        builder.create<memref::CloneOp>(terminator->getLoc(), sourceValue);
-
-    // Remember the clone of original source value.
-    clonedValues.insert(cloneOp);
-    return cloneOp;
+    auto clone = buildClone(terminator, sourceValue);
+    if (succeeded(clone)) {
+      // Remember the clone of original source value.
+      clonedValues.insert(*clone);
+    }
+    return clone;
   }

   /// Finds correct dealloc positions according to the algorithm described at
   /// the top of the file for all alloc nodes and block arguments that can be
   /// handled by this analysis.
-  void placeDeallocs() const {
+  LogicalResult placeDeallocs() {
     // Move or insert deallocs using the previously computed information.
     // These deallocations will be linked to their associated allocation nodes
     // since they don't have any aliases that can (potentially) increase their
@@ -492,10 +545,54 @@
         if (!nextOp)
           continue;
         // If there is no dealloc node, insert one in the right place.
-        OpBuilder builder(nextOp);
-        builder.create<memref::DeallocOp>(alloc.getLoc(), alloc);
+        if (failed(buildDealloc(nextOp, alloc)))
+          return failure();
       }
     }
+    return success();
+  }
+
+  /// Builds a deallocation operation compatible with the given allocation
+  /// value. If there is no registered AllocationOpInterface implementation for
+  /// the given value (e.g. in the case of a function parameter), this method
+  /// builds a memref::DeallocOp.
+  LogicalResult buildDealloc(Operation *op, Value alloc) {
+    OpBuilder builder(op);
+    auto it = aliasToAllocations.find(alloc);
+    if (it != aliasToAllocations.end() && it->second) {
+      // Call the allocation op interface to build a supported and
+      // compatible deallocation operation.
+      auto dealloc = it->second.buildDealloc(builder, alloc);
+      if (!dealloc)
+        return op->emitError()
+               << "allocations without compatible deallocations are "
+                  "not supported";
+    } else {
+      // Build a "default" DeallocOp for unknown or interface-less sources.
+      builder.create<memref::DeallocOp>(alloc.getLoc(), alloc);
+    }
+    return success();
+  }
+
+  /// Builds a clone operation compatible with the given allocation value. If
+  /// there is no registered (non-null) AllocationOpInterface implementation
+  /// for the given value (e.g. a function parameter or an allocation with an
+  /// explicit dealloc), this method builds a memref::CloneOp.
+  FailureOr<Value> buildClone(Operation *op, Value alloc) {
+    OpBuilder builder(op);
+    auto it = aliasToAllocations.find(alloc);
+    if (it != aliasToAllocations.end() && it->second) {
+      // Call the allocation op interface to build a supported and
+      // compatible clone operation.
+      auto clone = it->second.buildClone(builder, alloc);
+      if (clone)
+        return *clone;
+      return (LogicalResult)(op->emitError()
+                             << "allocations without compatible clone ops "
+                                "are not supported");
+    }
+    // Build a "default" CloneOp for unknown or interface-less sources.
+    return builder.create<memref::CloneOp>(alloc.getLoc(), alloc).getResult();
   }

   /// The dominator info to find the appropriate start operation to move the
@@ -508,6 +605,9 @@

   /// Stores already cloned buffers to avoid additional clones of clones.
   ValueSetT clonedValues;
+
+  /// Maps aliases to their source allocation interfaces (may be null).
+  AliasAllocationMapT aliasToAllocations;
 };

 //===----------------------------------------------------------------------===//
@@ -529,13 +629,20 @@
     }

     // Check that the control flow structures are supported.
-    if (!validateSupportedControlFlow(func.getRegion())) {
+    if (!validateSupportedControlFlow(func.getRegion()))
       return signalPassFailure();
-    }

-    // Place all required temporary clone and dealloc nodes.
+    // Gather all required allocation nodes and prepare the deallocation phase.
     BufferDeallocation deallocation(func);
-    deallocation.deallocate();
+
+    // Check for supported AllocationOpInterface implementations and prepare the
+    // internal deallocation pass.
+    if (failed(deallocation.prepare()))
+      return signalPassFailure();
+
+    // Place all required temporary clone and dealloc nodes.
+    if (failed(deallocation.deallocate()))
+      return signalPassFailure();
   }
 };
