diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp --- a/mlir/lib/Transforms/BufferPlacement.cpp +++ b/mlir/lib/Transforms/BufferPlacement.cpp @@ -67,6 +67,22 @@ namespace { +/// Walks over all return-like terminators that either exit the parent region or +/// a nested region. +template +static void walkReturnOperations(Operation *operation, const FuncT &func) { + auto attachedRegions = operation->getRegions(); + operation->walk([&](Operation *terminator) { + // Skip non-return-like terminators or return-like ones that do not satisfy + // the escaping constraint. + if (terminator->hasTrait() && + (llvm::find_if(attachedRegions, [&](Region &region) { + return &region == terminator->getParentRegion(); + }) == attachedRegions.end()) != exitParentRegion) + func(terminator); + }); +} + //===----------------------------------------------------------------------===// // BufferPlacementAliasAnalysis //===----------------------------------------------------------------------===// @@ -82,7 +98,7 @@ public: /// Constructs a new alias analysis using the op provided. - BufferPlacementAliasAnalysis(Operation *op) { build(op->getRegions()); } + BufferPlacementAliasAnalysis(Operation *op) { build(op); } /// Find all immediate aliases this value could potentially have. ValueMapT::const_iterator find(Value value) const { @@ -102,7 +118,7 @@ } /// Removes the given values from all alias sets. - void remove(const SmallPtrSetImpl &aliasValues) { + void remove(const SmallPtrSetImpl &aliasValues) { for (auto &entry : aliases) llvm::set_subtract(entry.second, aliasValues); } @@ -120,36 +136,64 @@ resolveRecursive(alias, result); } + /// Registers a new alias tuple entry (first element is the alias, the second + /// one is the source value). + void registerAlias(std::tuple aliasEntry) { + aliases[std::get<1>(aliasEntry)].insert(std::get<0>(aliasEntry)); + } + /// This function constructs a mapping from values to its immediate aliases. 
/// It iterates over all blocks, gets their predecessors, determines the /// values that will be passed to the corresponding block arguments and - /// inserts them into the underlying map. - void build(MutableArrayRef regions) { - for (Region &region : regions) { - for (Block &block : region) { - // Iterate over all predecessor and get the mapped values to their - // corresponding block arguments values. - for (auto it = block.pred_begin(), e = block.pred_end(); it != e; - ++it) { - unsigned successorIndex = it.getSuccessorIndex(); - // Get the terminator and the values that will be passed to our block. - auto branchInterface = - dyn_cast((*it)->getTerminator()); - if (!branchInterface) - continue; - // Query the branch op interace to get the successor operands. - auto successorOperands = - branchInterface.getSuccessorOperands(successorIndex); - if (successorOperands.hasValue()) { - // Build the actual mapping of values to their immediate aliases. - for (auto argPair : llvm::zip(block.getArguments(), - successorOperands.getValue())) { - aliases[std::get<1>(argPair)].insert(std::get<0>(argPair)); - } - } + /// inserts them into the underlying map. Furthermore, it queries detailed + /// information about successor regions and branch-like return operations + /// from nested regions. + void build(Operation *op) { + op->walk([&](BranchOpInterface branchInterface) { + Block *parentBlock = branchInterface.getOperation()->getBlock(); + for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end(); + it != e; ++it) { + // Query the branch op interface to get the successor operands. + auto successorOperands = + branchInterface.getSuccessorOperands(it.getIndex()); + if (!successorOperands.hasValue()) + continue; + // Build the actual mapping of values to their immediate aliases. 
+ for (auto argPair : + llvm::zip((*it)->getArguments(), successorOperands.getValue())) { + registerAlias(argPair); } } - } + }); + + // Query the RegionBranchOpInterface to find potential successor regions. + op->walk([&](RegionBranchOpInterface regionInterface) { + // Create an empty attribute for each operand to comply with the + // `getSuccessorRegions` interface definition that requires a single + // attribute per operand. + SmallVector operandAttributes( + regionInterface.getOperation()->getNumOperands()); + // Extract all region successors. + SmallVector regionSuccessors; + regionInterface.getSuccessorRegions(llvm::None, operandAttributes, + regionSuccessors); + for (auto regionSuccessor : regionSuccessors) { + Block *successorBlock = &*regionSuccessor.getSuccessor()->begin(); + for (auto argPair : llvm::zip(successorBlock->getArguments(), + regionSuccessor.getSuccessorInputs())) { + registerAlias(argPair); + } + } + }); + + // Query all return operations that leave nested regions. + walkReturnOperations(op, [&](Operation *terminator) { + Operation *parentOp = terminator->getParentOp(); + for (auto argPair : + llvm::zip(parentOp->getResults(), terminator->getOperands())) { + registerAlias(argPair); + } + }); } /// Maps values to all immediate aliases this value can have. @@ -235,14 +279,20 @@ Block *getInitialAllocBlock(OpResult result) { // Get all allocation operands as these operands are important for the // allocation operation. - auto operands = result.getOwner()->getOperands(); + Operation *owner = result.getOwner(); + auto operands = owner->getOperands(); if (operands.size() < 1) return findCommonDominator(result, aliases.resolve(result), dominators); // If this node has dependencies, check all dependent nodes with respect // to a common post dominator in which all values are available. 
ValueSetT dependencies(++operands.begin(), operands.end()); - return findCommonDominator(*operands.begin(), dependencies, postDominators); + Block *dominator = + findCommonDominator(*operands.begin(), dependencies, postDominators); + // Do not move allocs out of their parent regions to keep them local. + if (dominator->getParent() != owner->getParentRegion()) + return &*owner->getParentRegion()->begin(); + return dominator; } /// Finds correct alloc positions according to the algorithm described at @@ -273,12 +323,12 @@ /// Introduces required allocs and copy operations to avoid memory leaks. void introduceCopies() { - // Initialize the set of block arguments that require a dedicated memory - // free operation since their arguments cannot be safely deallocated in a - // post dominator. - SmallPtrSet blockArgsToFree; - llvm::SmallDenseSet> visitedBlockArgs; - SmallVector, 8> toProcess; + // Initialize the set of values that require a dedicated memory free + // operation since their operands cannot be safely deallocated in a post + // dominator. + SmallPtrSet valuesToFree; + llvm::SmallDenseSet> visitedValues; + SmallVector, 8> toProcess; // Check dominance relation for proper dominance properties. If the given // value node does not dominate an alias, we will have to create a copy in @@ -289,17 +339,15 @@ if (it == aliases.end()) return; for (Value value : it->second) { - auto blockArg = value.cast(); - if (blockArgsToFree.count(blockArg) > 0) + if (valuesToFree.count(value) > 0) continue; // Check whether we have to free this particular block argument. 
- if (!dominators.dominates(definingBlock, blockArg.getOwner())) { - toProcess.emplace_back(blockArg, blockArg.getParentBlock()); - blockArgsToFree.insert(blockArg); - } else if (visitedBlockArgs - .insert(std::make_tuple(blockArg, definingBlock)) + if (!dominators.dominates(definingBlock, value.getParentBlock())) { + toProcess.emplace_back(value, value.getParentBlock()); + valuesToFree.insert(value); + } else if (visitedValues.insert(std::make_tuple(value, definingBlock)) .second) - toProcess.emplace_back(blockArg, definingBlock); + toProcess.emplace_back(value, definingBlock); } }; @@ -316,60 +364,97 @@ // Update buffer aliases to ensure that we free all buffers and block // arguments at the correct locations. - aliases.remove(blockArgsToFree); + aliases.remove(valuesToFree); // Add new allocs and additional copy operations. - for (BlockArgument blockArg : blockArgsToFree) { - Block *block = blockArg.getOwner(); - - // Allocate a buffer for the current block argument in the block of - // the associated value (which will be a predecessor block by - // definition). - for (auto it = block->pred_begin(), e = block->pred_end(); it != e; - ++it) { - // Get the terminator and the value that will be passed to our - // argument. - Operation *terminator = (*it)->getTerminator(); - auto branchInterface = cast(terminator); - // Convert the mutable operand range to an immutable range and query the - // associated source value. - Value sourceValue = - branchInterface.getSuccessorOperands(it.getSuccessorIndex()) - .getValue()[blockArg.getArgNumber()]; - // Create a new alloc at the current location of the terminator. - auto memRefType = sourceValue.getType().cast(); - OpBuilder builder(terminator); - - // Extract information about dynamically shaped types by - // extracting their dynamic dimensions. 
- SmallVector dynamicOperands; - for (auto shapeElement : llvm::enumerate(memRefType.getShape())) { - if (!ShapedType::isDynamic(shapeElement.value())) - continue; - dynamicOperands.push_back(builder.create( - terminator->getLoc(), sourceValue, shapeElement.index())); - } + for (Value value : valuesToFree) { + if (auto blockArg = value.dyn_cast()) + introduceBlockArgCopy(blockArg); + else + introduceValueCopy(value); + + // Register the value to require a final dealloc. Note that we do not have + // to assign a block here since we do not want to move the allocation node + // to another location. + allocs.push_back({value, nullptr, nullptr}); + } + } - // TODO: provide a generic interface to create dialect-specific - // Alloc and CopyOp nodes. - auto alloc = builder.create(terminator->getLoc(), memRefType, - dynamicOperands); - // Wire new alloc and successor operand. - branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex()) - .getValue() - .slice(blockArg.getArgNumber(), 1) - .assign(alloc); - // Create a new copy operation that copies to contents of the old - // allocation to the new one. - builder.create(terminator->getLoc(), sourceValue, - alloc); - } + /// Introduces temporary allocs in all predecessors and copies the source + /// values into the newly allocated buffers. + void introduceBlockArgCopy(BlockArgument blockArg) { + // Allocate a buffer for the current block argument in the block of + // the associated value (which will be a predecessor block by + // definition). + Block *block = blockArg.getOwner(); + for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { + // Get the terminator and the value that will be passed to our + // argument. + Operation *terminator = (*it)->getTerminator(); + auto branchInterface = cast(terminator); + // Convert the mutable operand range to an immutable range and query the + // associated source value. 
+ Value sourceValue = + branchInterface.getSuccessorOperands(it.getSuccessorIndex()) + .getValue()[blockArg.getArgNumber()]; + // Create a new alloc at the current location of the terminator. + Value alloc = introduceBufferCopy(sourceValue, terminator); + // Wire new alloc and successor operand. + branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex()) + .getValue() + .slice(blockArg.getArgNumber(), 1) + .assign(alloc); + } + } + + /// Introduces temporary allocs in front of all associated nested-region + /// terminators and copies the source values into the newly allocated buffers. + void introduceValueCopy(Value value) { + // Get the actual result index in the scope of the parent terminator. + Operation *operation = value.getDefiningOp(); + auto resultIndex = + llvm::find_if(operation->getResults(), [&](OpResult result) { + return result == value; + }).getIndex(); + + walkReturnOperations(operation, [&](Operation *terminator) { + // Extract the source value from the current terminator. + Value sourceValue = terminator->getOperand(resultIndex); + // Create a new alloc at the current location of the terminator. + Value alloc = introduceBufferCopy(sourceValue, terminator); + // Wire alloc and terminator operand. + terminator->setOperand(resultIndex, alloc); + }); + } - // Register the block argument to require a final dealloc. Note that - // we do not have to assign a block here since we do not want to - // move the allocation node to another location. - allocs.push_back({blockArg, nullptr, nullptr}); + /// Creates a new memory allocation for the given source value and copies its + /// content into the newly allocated buffer. The terminator operation is used + /// to insert the alloc and copy operations at the right places. + Value introduceBufferCopy(Value sourceValue, Operation *terminator) { + // Create a new alloc at the current location of the terminator. 
+ auto memRefType = sourceValue.getType().cast(); + OpBuilder builder(terminator); + + // Extract information about dynamically shaped types by + // extracting their dynamic dimensions. + SmallVector dynamicOperands; + for (auto shapeElement : llvm::enumerate(memRefType.getShape())) { + if (!ShapedType::isDynamic(shapeElement.value())) + continue; + dynamicOperands.push_back(builder.create( + terminator->getLoc(), sourceValue, shapeElement.index())); } + + // TODO: provide a generic interface to create dialect-specific + // Alloc and CopyOp nodes. + auto alloc = builder.create(terminator->getLoc(), memRefType, + dynamicOperands); + + // Create a new copy operation that copies to contents of the old + // allocation to the new one. + builder.create(terminator->getLoc(), sourceValue, alloc); + + return alloc; } /// Finds associated deallocs that can be linked to our allocation nodes (if diff --git a/mlir/test/Transforms/buffer-placement.mlir b/mlir/test/Transforms/buffer-placement.mlir --- a/mlir/test/Transforms/buffer-placement.mlir +++ b/mlir/test/Transforms/buffer-placement.mlir @@ -716,3 +716,66 @@ // CHECK: dealloc %[[Y]] // CHECK: return %[[ARG1]], %[[X]] +// ----- + +// Test Case: nested region control flow +// The alloc position of %1 does not need to be changed and flows through +// both if branches until it is finally returned. Hence, it does not +// require a specific dealloc operation. However, %3 requires a dealloc. 
+ +func @nested_region_control_flow( + %arg0 : index, + %arg1 : index) -> memref { + %0 = cmpi "eq", %arg0, %arg1 : index + %1 = alloc(%arg0, %arg0) : memref + %2 = scf.if %0 -> (memref) { + scf.yield %1 : memref + } else { + %3 = alloc(%arg0, %arg1) : memref + scf.yield %1 : memref + } + return %2 : memref +} + +// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0) +// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if +// CHECK: scf.yield %[[ALLOC0]] +// CHECK: %[[ALLOC2:.*]] = alloc(%arg0, %arg1) +// CHECK-NEXT: dealloc %[[ALLOC2]] +// CHECK-NEXT: scf.yield %[[ALLOC0]] +// CHECK: return %[[ALLOC1]] + +// ----- + +// Test Case: nested region control flow with a nested buffer allocation in a +// divergent branch. +// The alloc positions of %1 and %3 do not need to be changed since +// BufferPlacement does not move allocs out of nested regions at the moment. +// However, since %3 is allocated and "returned" in a divergent branch, we have +// to allocate a temporary buffer (like in condBranchDynamicTypeNested). + +func @nested_region_control_flow_div( + %arg0 : index, + %arg1 : index) -> memref { + %0 = cmpi "eq", %arg0, %arg1 : index + %1 = alloc(%arg0, %arg0) : memref + %2 = scf.if %0 -> (memref) { + scf.yield %1 : memref + } else { + %3 = alloc(%arg0, %arg1) : memref + scf.yield %3 : memref + } + return %2 : memref +} + +// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0) +// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if +// CHECK: %[[ALLOC2:.*]] = alloc +// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[ALLOC2]]) +// CHECK: scf.yield %[[ALLOC2]] +// CHECK: %[[ALLOC3:.*]] = alloc(%arg0, %arg1) +// CHECK: %[[ALLOC4:.*]] = alloc +// CHECK-NEXT: linalg.copy(%[[ALLOC3]], %[[ALLOC4]]) +// CHECK: scf.yield %[[ALLOC4]] +// CHECK: dealloc %[[ALLOC0]] +// CHECK-NEXT: return %[[ALLOC1]]