diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -65,8 +65,18 @@
 
 using namespace mlir;
 
-namespace {
+/// Walks over all immediate return-like terminators in the given region.
+template <typename FuncT>
+static void walkReturnOperations(Region *region, const FuncT &func) {
+  for (Block &block : *region)
+    for (Operation &operation : block) {
+      // Skip non-return-like terminators.
+      if (operation.hasTrait<OpTrait::ReturnLike>())
+        func(&operation);
+    }
+}
 
+namespace {
 //===----------------------------------------------------------------------===//
 // BufferPlacementAliasAnalysis
 //===----------------------------------------------------------------------===//
@@ -82,7 +92,7 @@
 
 public:
   /// Constructs a new alias analysis using the op provided.
-  BufferPlacementAliasAnalysis(Operation *op) { build(op->getRegions()); }
+  BufferPlacementAliasAnalysis(Operation *op) { build(op); }
 
   /// Find all immediate aliases this value could potentially have.
   ValueMapT::const_iterator find(Value value) const {
@@ -102,7 +112,7 @@
   }
 
   /// Removes the given values from all alias sets.
-  void remove(const SmallPtrSetImpl<BlockArgument> &aliasValues) {
+  void remove(const SmallPtrSetImpl<Value> &aliasValues) {
     for (auto &entry : aliases)
       llvm::set_subtract(entry.second, aliasValues);
   }
@@ -123,33 +133,69 @@
   /// This function constructs a mapping from values to its immediate aliases.
   /// It iterates over all blocks, gets their predecessors, determines the
   /// values that will be passed to the corresponding block arguments and
-  /// inserts them into the underlying map.
-  void build(MutableArrayRef<Region> regions) {
-    for (Region &region : regions) {
-      for (Block &block : region) {
-        // Iterate over all predecessor and get the mapped values to their
-        // corresponding block arguments values.
-        for (auto it = block.pred_begin(), e = block.pred_end(); it != e;
-             ++it) {
-          unsigned successorIndex = it.getSuccessorIndex();
-          // Get the terminator and the values that will be passed to our block.
-          auto branchInterface =
-              dyn_cast<BranchOpInterface>((*it)->getTerminator());
-          if (!branchInterface)
-            continue;
-          // Query the branch op interace to get the successor operands.
-          auto successorOperands =
-              branchInterface.getSuccessorOperands(successorIndex);
-          if (successorOperands.hasValue()) {
-            // Build the actual mapping of values to their immediate aliases.
-            for (auto argPair : llvm::zip(block.getArguments(),
-                                          successorOperands.getValue())) {
-              aliases[std::get<1>(argPair)].insert(std::get<0>(argPair));
-            }
-          }
+  /// inserts them into the underlying map. Furthermore, it wires successor
+  /// regions and branch-like return operations from nested regions.
+  void build(Operation *op) {
+    // Registers all aliases of the given values.
+    auto registerAliases = [&](auto values, auto aliasValues) {
+      for (auto entry : llvm::zip(values, aliasValues))
+        aliases[std::get<0>(entry)].insert(std::get<1>(entry));
+    };
+
+    // Query all branch interfaces to link block argument aliases.
+    op->walk([&](BranchOpInterface branchInterface) {
+      Block *parentBlock = branchInterface.getOperation()->getBlock();
+      for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
+           it != e; ++it) {
+        // Query the branch op interface to get the successor operands.
+        auto successorOperands =
+            branchInterface.getSuccessorOperands(it.getIndex());
+        if (!successorOperands.hasValue())
+          continue;
+        // Build the actual mapping of values to their immediate aliases.
+        registerAliases(successorOperands.getValue(), (*it)->getArguments());
+      }
+    });
+
+    // Query the RegionBranchOpInterface to find potential successor regions.
+    op->walk([&](RegionBranchOpInterface regionInterface) {
+      // Create an empty attribute for each operand to comply with the
+      // `getSuccessorRegions` interface definition that requires a single
+      // attribute per operand.
+      SmallVector<Attribute, 2> operandAttributes(
+          regionInterface.getOperation()->getNumOperands());
+
+      // Extract all entry regions and wire all initial entry successor inputs.
+      SmallVector<RegionSuccessor, 2> entrySuccessors;
+      regionInterface.getSuccessorRegions(/*index=*/llvm::None,
+                                          operandAttributes, entrySuccessors);
+      for (RegionSuccessor &entrySuccessor : entrySuccessors) {
+        // Wire the entry region's successor arguments with the initial
+        // successor inputs.
+        assert(entrySuccessor.getSuccessor() &&
+               "Invalid entry region without an attached successor region");
+        registerAliases(regionInterface.getSuccessorEntryOperands(
+                            entrySuccessor.getSuccessor()->getRegionNumber()),
+                        entrySuccessor.getSuccessorInputs());
+      }
+
+      // Wire flow between regions and from region exits.
+      for (Region &region : regionInterface.getOperation()->getRegions()) {
+        // Iterate over all successor region entries that are reachable from the
+        // current region.
+        SmallVector<RegionSuccessor, 2> successorRegions;
+        regionInterface.getSuccessorRegions(
+            region.getRegionNumber(), operandAttributes, successorRegions);
+        for (RegionSuccessor &successorRegion : successorRegions) {
+          // Iterate over all immediate terminator operations and wire the
+          // successor inputs with the operands of each terminator.
+          walkReturnOperations(&region, [&](Operation *terminator) {
+            registerAliases(terminator->getOperands(),
+                            successorRegion.getSuccessorInputs());
+          });
         }
       }
-    }
+    });
   }
 
   /// Maps values to all immediate aliases this value can have.
@@ -235,14 +281,24 @@
   Block *getInitialAllocBlock(OpResult result) {
     // Get all allocation operands as these operands are important for the
     // allocation operation.
-    auto operands = result.getOwner()->getOperands();
+    Operation *owner = result.getOwner();
+    auto operands = owner->getOperands();
+    Block *dominator;
     if (operands.size() < 1)
-      return findCommonDominator(result, aliases.resolve(result), dominators);
+      dominator =
+          findCommonDominator(result, aliases.resolve(result), dominators);
+    else {
+      // If this node has dependencies, check all dependent nodes with respect
+      // to a common post dominator in which all values are available.
+      ValueSetT dependencies(++operands.begin(), operands.end());
+      dominator =
+          findCommonDominator(*operands.begin(), dependencies, postDominators);
+    }
 
-    // If this node has dependencies, check all dependent nodes with respect
-    // to a common post dominator in which all values are available.
-    ValueSetT dependencies(++operands.begin(), operands.end());
-    return findCommonDominator(*operands.begin(), dependencies, postDominators);
+    // Do not move allocs out of their parent regions to keep them local.
+    if (dominator->getParent() != owner->getParentRegion())
+      return &owner->getParentRegion()->front();
+    return dominator;
   }
 
   /// Finds correct alloc positions according to the algorithm described at
@@ -273,12 +329,12 @@
 
   /// Introduces required allocs and copy operations to avoid memory leaks.
   void introduceCopies() {
-    // Initialize the set of block arguments that require a dedicated memory
-    // free operation since their arguments cannot be safely deallocated in a
-    // post dominator.
-    SmallPtrSet<BlockArgument, 8> blockArgsToFree;
-    llvm::SmallDenseSet<std::tuple<BlockArgument, Block *>> visitedBlockArgs;
-    SmallVector<std::tuple<Value, Block *>, 8> toProcess;
+    // Initialize the set of values that require a dedicated memory free
+    // operation since their operands cannot be safely deallocated in a post
+    // dominator.
+    SmallPtrSet<Value, 8> valuesToFree;
+    llvm::SmallDenseSet<std::tuple<Value, Block *>> visitedValues;
+    SmallVector<std::tuple<Value, Block *>, 8> toProcess;
 
     // Check dominance relation for proper dominance properties. If the given
     // value node does not dominate an alias, we will have to create a copy in
@@ -289,17 +345,15 @@
       if (it == aliases.end())
         return;
       for (Value value : it->second) {
-        auto blockArg = value.cast<BlockArgument>();
-        if (blockArgsToFree.count(blockArg) > 0)
+        if (valuesToFree.count(value) > 0)
           continue;
         // Check whether we have to free this particular block argument.
-        if (!dominators.dominates(definingBlock, blockArg.getOwner())) {
-          toProcess.emplace_back(blockArg, blockArg.getParentBlock());
-          blockArgsToFree.insert(blockArg);
-        } else if (visitedBlockArgs
-                       .insert(std::make_tuple(blockArg, definingBlock))
+        if (!dominators.dominates(definingBlock, value.getParentBlock())) {
+          toProcess.emplace_back(value, value.getParentBlock());
+          valuesToFree.insert(value);
+        } else if (visitedValues.insert(std::make_tuple(value, definingBlock))
                        .second)
-          toProcess.emplace_back(blockArg, definingBlock);
+          toProcess.emplace_back(value, definingBlock);
       }
     };
 
@@ -316,62 +370,170 @@
 
     // Update buffer aliases to ensure that we free all buffers and block
     // arguments at the correct locations.
-    aliases.remove(blockArgsToFree);
+    aliases.remove(valuesToFree);
 
     // Add new allocs and additional copy operations.
-    for (BlockArgument blockArg : blockArgsToFree) {
-      Block *block = blockArg.getOwner();
-
-      // Allocate a buffer for the current block argument in the block of
-      // the associated value (which will be a predecessor block by
-      // definition).
-      for (auto it = block->pred_begin(), e = block->pred_end(); it != e;
-           ++it) {
-        // Get the terminator and the value that will be passed to our
-        // argument.
-        Operation *terminator = (*it)->getTerminator();
-        auto branchInterface = cast<BranchOpInterface>(terminator);
-        // Convert the mutable operand range to an immutable range and query the
-        // associated source value.
-        Value sourceValue =
-            branchInterface.getSuccessorOperands(it.getSuccessorIndex())
-                .getValue()[blockArg.getArgNumber()];
-        // Create a new alloc at the current location of the terminator.
-        auto memRefType = sourceValue.getType().cast<MemRefType>();
-        OpBuilder builder(terminator);
-
-        // Extract information about dynamically shaped types by
-        // extracting their dynamic dimensions.
-        SmallVector<Value, 4> dynamicOperands;
-        for (auto shapeElement : llvm::enumerate(memRefType.getShape())) {
-          if (!ShapedType::isDynamic(shapeElement.value()))
-            continue;
-          dynamicOperands.push_back(builder.create<DimOp>(
-              terminator->getLoc(), sourceValue, shapeElement.index()));
-        }
+    for (Value value : valuesToFree) {
+      if (auto blockArg = value.dyn_cast<BlockArgument>())
+        introduceBlockArgCopy(blockArg);
+      else
+        introduceValueCopyForRegionResult(value);
+
+      // Register the value to require a final dealloc. Note that we do not have
+      // to assign a block here since we do not want to move the allocation node
+      // to another location.
+      allocs.push_back({value, nullptr, nullptr});
+    }
+  }
 
-        // TODO: provide a generic interface to create dialect-specific
-        // Alloc and CopyOp nodes.
-        auto alloc = builder.create<AllocOp>(terminator->getLoc(), memRefType,
-                                             dynamicOperands);
-        // Wire new alloc and successor operand.
-        branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex())
-            .getValue()
+  /// Introduces temporary allocs in all predecessors and copies the source
+  /// values into the newly allocated buffers.
+  void introduceBlockArgCopy(BlockArgument blockArg) {
+    // Allocate a buffer for the current block argument in the block of
+    // the associated value (which will be a predecessor block by
+    // definition).
+    Block *block = blockArg.getOwner();
+    for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
+      // Get the terminator and the value that will be passed to our
+      // argument.
+      Operation *terminator = (*it)->getTerminator();
+      auto branchInterface = cast<BranchOpInterface>(terminator);
+      // Query the associated source value.
+      Value sourceValue =
+          branchInterface.getSuccessorOperands(it.getSuccessorIndex())
+              .getValue()[blockArg.getArgNumber()];
+      // Create a new alloc and copy at the current location of the terminator.
+      Value alloc = introduceBufferCopy(sourceValue, terminator);
+      // Wire new alloc and successor operand.
+      auto mutableOperands =
+          branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex());
+      if (!mutableOperands.hasValue())
+        terminator->emitError() << "terminators with immutable successor "
+                                   "operands are not supported";
+      else
+        mutableOperands.getValue()
             .slice(blockArg.getArgNumber(), 1)
             .assign(alloc);
-        // Create a new copy operation that copies to contents of the old
-        // allocation to the new one.
-        builder.create<linalg::CopyOp>(terminator->getLoc(), sourceValue,
-                                       alloc);
-      }
+    }
 
-      // Register the block argument to require a final dealloc. Note that
-      // we do not have to assign a block here since we do not want to
-      // move the allocation node to another location.
-      allocs.push_back({blockArg, nullptr, nullptr});
+    // Check whether the block argument has implicitly defined predecessors via
+    // the RegionBranchOpInterface. This can be the case if the current block
+    // argument belongs to the first block in a region and the parent operation
+    // implements the RegionBranchOpInterface.
+    Region *argRegion = block->getParent();
+    RegionBranchOpInterface regionInterface;
+    if (!argRegion || &argRegion->front() != block ||
+        !(regionInterface =
+              dyn_cast<RegionBranchOpInterface>(argRegion->getParentOp())))
+      return;
+
+    introduceCopiesForRegionSuccessors(
+        regionInterface, argRegion->getParentOp()->getRegions(),
+        [&](RegionSuccessor &successorRegion) {
+          // Find a predecessor of our argRegion.
+          return successorRegion.getSuccessor() == argRegion;
+        },
+        [&](RegionSuccessor &successorRegion) {
+          // The operand index will be the argument number.
+          return blockArg.getArgNumber();
+        });
+  }
+
+  /// Introduces temporary allocs in front of all associated nested-region
+  /// terminators and copies the source values into the newly allocated buffers.
+  void introduceValueCopyForRegionResult(Value value) {
+    // Get the actual result index in the scope of the parent terminator.
+    Operation *operation = value.getDefiningOp();
+    auto regionInterface = cast<RegionBranchOpInterface>(operation);
+    introduceCopiesForRegionSuccessors(
+        regionInterface, operation->getRegions(),
+        [&](RegionSuccessor &successorRegion) {
+          // Determine whether this region has a successor entry that leaves
+          // this region by returning to its parent operation.
+          return !successorRegion.getSuccessor();
+        },
+        [&](RegionSuccessor &successorRegion) {
+          // Find the index of `value` among the successor inputs.
+          return llvm::find(successorRegion.getSuccessorInputs(), value)
+              .getIndex();
+        });
+  }
+
+  /// Introduces buffer copies for all terminators in the given regions. The
+  /// regionPredicate is applied to every successor region in order to restrict
+  /// the copies to specific regions. Thereby, the operandProvider is invoked
+  /// for each matching region successor and determines the operand index that
+  /// requires a buffer copy.
+  template <typename TPredicate, typename TOperandProvider>
+  void
+  introduceCopiesForRegionSuccessors(RegionBranchOpInterface regionInterface,
+                                     MutableArrayRef<Region> regions,
+                                     const TPredicate &regionPredicate,
+                                     const TOperandProvider &operandProvider) {
+    // Create an empty attribute for each operand to comply with the
+    // `getSuccessorRegions` interface definition that requires a single
+    // attribute per operand.
+    SmallVector<Attribute, 2> operandAttributes(
+        regionInterface.getOperation()->getNumOperands());
+    for (Region &region : regions) {
+      // Query the regionInterface to get all successor regions of the current
+      // one.
+      SmallVector<RegionSuccessor, 2> successorRegions;
+      regionInterface.getSuccessorRegions(region.getRegionNumber(),
+                                          operandAttributes, successorRegions);
+      // Try to find a matching region successor.
+      RegionSuccessor *regionSuccessor =
+          llvm::find_if(successorRegions, regionPredicate);
+      if (regionSuccessor == successorRegions.end())
+        continue;
+      // Get the operand index in the context of the current successor input
+      // bindings.
+      auto operandIndex = operandProvider(*regionSuccessor);
+
+      // Iterate over all immediate terminator operations to introduce
+      // new buffer allocations. Thereby, the appropriate terminator operand
+      // will be adjusted to point to the newly allocated buffer instead.
+      walkReturnOperations(&region, [&](Operation *terminator) {
+        // Extract the source value from the current terminator.
+        Value sourceValue = terminator->getOperand(operandIndex);
+        // Create a new alloc at the current location of the terminator.
+        Value alloc = introduceBufferCopy(sourceValue, terminator);
+        // Wire alloc and terminator operand.
+        terminator->setOperand(operandIndex, alloc);
+      });
     }
   }
 
+  /// Creates a new memory allocation for the given source value and copies
+  /// its content into the newly allocated buffer. The terminator operation is
+  /// used to insert the alloc and copy operations at the right places.
+  Value introduceBufferCopy(Value sourceValue, Operation *terminator) {
+    // Create a new alloc at the current location of the terminator.
+    auto memRefType = sourceValue.getType().cast<MemRefType>();
+    OpBuilder builder(terminator);
+
+    // Extract information about dynamically shaped types by
+    // extracting their dynamic dimensions.
+    SmallVector<Value, 4> dynamicOperands;
+    for (auto shapeElement : llvm::enumerate(memRefType.getShape())) {
+      if (!ShapedType::isDynamic(shapeElement.value()))
+        continue;
+      dynamicOperands.push_back(builder.create<DimOp>(
+          terminator->getLoc(), sourceValue, shapeElement.index()));
+    }
+
+    // TODO: provide a generic interface to create dialect-specific
+    // Alloc and CopyOp nodes.
+    auto alloc = builder.create<AllocOp>(terminator->getLoc(), memRefType,
+                                         dynamicOperands);
+
+    // Create a new copy operation that copies the contents of the old
+    // allocation to the new one.
+    builder.create<linalg::CopyOp>(terminator->getLoc(), sourceValue, alloc);
+
+    return alloc;
+  }
+
   /// Finds associated deallocs that can be linked to our allocation nodes (if
   /// any).
   void findDeallocs() {
@@ -440,8 +602,8 @@
         if (entry.deallocOperation) {
           entry.deallocOperation->moveAfter(endOperation);
         } else {
-          // If the Dealloc position is at the terminator operation of the block,
-          // then the value should escape from a deallocation.
+          // If the Dealloc position is at the terminator operation of the
+          // block, then the value should escape from a deallocation.
           Operation *nextOp = endOperation->getNextNode();
           if (!nextOp)
             continue;
diff --git a/mlir/test/Transforms/buffer-placement.mlir b/mlir/test/Transforms/buffer-placement.mlir
--- a/mlir/test/Transforms/buffer-placement.mlir
+++ b/mlir/test/Transforms/buffer-placement.mlir
@@ -716,3 +716,201 @@
 // CHECK: dealloc %[[Y]]
 // CHECK: return %[[ARG1]], %[[X]]
 
+// -----
+
+// Test Case: nested region control flow
+// The alloc position of %1 does not need to be changed and flows through
+// both if branches until it is finally returned. Hence, it does not
+// require a specific dealloc operation. However, %3 requires a dealloc.
+
+// CHECK-LABEL: func @nested_region_control_flow
+func @nested_region_control_flow(
+  %arg0 : index,
+  %arg1 : index) -> memref<?x?xf32> {
+  %0 = cmpi "eq", %arg0, %arg1 : index
+  %1 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %2 = scf.if %0 -> (memref<?x?xf32>) {
+    scf.yield %1 : memref<?x?xf32>
+  } else {
+    %3 = alloc(%arg0, %arg1) : memref<?x?xf32>
+    scf.yield %1 : memref<?x?xf32>
+  }
+  return %2 : memref<?x?xf32>
+}
+
+// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
+// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if
+// CHECK: scf.yield %[[ALLOC0]]
+// CHECK: %[[ALLOC2:.*]] = alloc(%arg0, %arg1)
+// CHECK-NEXT: dealloc %[[ALLOC2]]
+// CHECK-NEXT: scf.yield %[[ALLOC0]]
+// CHECK: return %[[ALLOC1]]
+
+// -----
+
+// Test Case: nested region control flow with a nested buffer allocation in a
+// divergent branch.
+// The alloc positions of %1 and %3 do not need to be changed since
+// BufferPlacement does not move allocs out of nested regions at the moment.
+// However, since %3 is allocated and "returned" in a divergent branch, we have
+// to allocate a temporary buffer (like in condBranchDynamicTypeNested).
+
+// CHECK-LABEL: func @nested_region_control_flow_div
+func @nested_region_control_flow_div(
+  %arg0 : index,
+  %arg1 : index) -> memref<?x?xf32> {
+  %0 = cmpi "eq", %arg0, %arg1 : index
+  %1 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %2 = scf.if %0 -> (memref<?x?xf32>) {
+    scf.yield %1 : memref<?x?xf32>
+  } else {
+    %3 = alloc(%arg0, %arg1) : memref<?x?xf32>
+    scf.yield %3 : memref<?x?xf32>
+  }
+  return %2 : memref<?x?xf32>
+}
+
+// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
+// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if
+// CHECK: %[[ALLOC2:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[ALLOC2]])
+// CHECK: scf.yield %[[ALLOC2]]
+// CHECK: %[[ALLOC3:.*]] = alloc(%arg0, %arg1)
+// CHECK: %[[ALLOC4:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC3]], %[[ALLOC4]])
+// CHECK: dealloc %[[ALLOC3]]
+// CHECK: scf.yield %[[ALLOC4]]
+// CHECK: dealloc %[[ALLOC0]]
+// CHECK-NEXT: return %[[ALLOC1]]
+
+// -----
+
+// Test Case: deeply nested region control flow with a nested buffer allocation
+// in a divergent branch.
+// The alloc positions of %1, %4 and %5 do not need to be changed since
+// BufferPlacement does not move allocs out of nested regions at the moment.
+// However, since %4 is allocated and "returned" in a divergent branch, we have
+// to allocate several temporary buffers (like in condBranchDynamicTypeNested).
+
+// CHECK-LABEL: func @nested_region_control_flow_div_nested
+func @nested_region_control_flow_div_nested(
+  %arg0 : index,
+  %arg1 : index) -> memref<?x?xf32> {
+  %0 = cmpi "eq", %arg0, %arg1 : index
+  %1 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %2 = scf.if %0 -> (memref<?x?xf32>) {
+    %3 = scf.if %0 -> (memref<?x?xf32>) {
+      scf.yield %1 : memref<?x?xf32>
+    } else {
+      %4 = alloc(%arg0, %arg1) : memref<?x?xf32>
+      scf.yield %4 : memref<?x?xf32>
+    }
+    scf.yield %3 : memref<?x?xf32>
+  } else {
+    %5 = alloc(%arg1, %arg1) : memref<?x?xf32>
+    scf.yield %5 : memref<?x?xf32>
+  }
+  return %2 : memref<?x?xf32>
+}
+// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
+// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if
+// CHECK-NEXT: %[[ALLOC2:.*]] = scf.if
+// CHECK: %[[ALLOC3:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[ALLOC3]])
+// CHECK: scf.yield %[[ALLOC3]]
+// CHECK: %[[ALLOC4:.*]] = alloc(%arg0, %arg1)
+// CHECK: %[[ALLOC5:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC4]], %[[ALLOC5]])
+// CHECK: dealloc %[[ALLOC4]]
+// CHECK: scf.yield %[[ALLOC5]]
+// CHECK: %[[ALLOC6:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC2]], %[[ALLOC6]])
+// CHECK: dealloc %[[ALLOC2]]
+// CHECK: scf.yield %[[ALLOC6]]
+// CHECK: %[[ALLOC7:.*]] = alloc(%arg1, %arg1)
+// CHECK: %[[ALLOC8:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC7]], %[[ALLOC8]])
+// CHECK: dealloc %[[ALLOC7]]
+// CHECK: scf.yield %[[ALLOC8]]
+// CHECK: dealloc %[[ALLOC0]]
+// CHECK-NEXT: return %[[ALLOC1]]
+
+// -----
+
+// Test Case: nested region control flow within a region interface.
+// The alloc position of %0 does not need to be changed and no copies are
+// required in this case since the allocation finally escapes the method.
+
+// CHECK-LABEL: func @inner_region_control_flow
+func @inner_region_control_flow(%arg0 : index) -> memref<?x?xf32> {
+  %0 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %1 = test.region_if %0 : memref<?x?xf32> -> (memref<?x?xf32>) then {
+    ^bb0(%arg1 : memref<?x?xf32>):
+      test.region_if_yield %arg1 : memref<?x?xf32>
+  } else {
+    ^bb0(%arg1 : memref<?x?xf32>):
+      test.region_if_yield %arg1 : memref<?x?xf32>
+  } join {
+    ^bb0(%arg1 : memref<?x?xf32>):
+      test.region_if_yield %arg1 : memref<?x?xf32>
+  }
+  return %1 : memref<?x?xf32>
+}
+
+// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
+// CHECK-NEXT: %[[ALLOC1:.*]] = test.region_if
+// CHECK-NEXT: ^bb0(%[[ALLOC2:.*]]:{{.*}}):
+// CHECK-NEXT: test.region_if_yield %[[ALLOC2]]
+// CHECK: ^bb0(%[[ALLOC3:.*]]:{{.*}}):
+// CHECK-NEXT: test.region_if_yield %[[ALLOC3]]
+// CHECK: ^bb0(%[[ALLOC4:.*]]:{{.*}}):
+// CHECK-NEXT: test.region_if_yield %[[ALLOC4]]
+// CHECK: return %[[ALLOC1]]
+
+// -----
+
+// Test Case: nested region control flow within a region interface including an
+// allocation in a divergent branch.
+// The alloc positions of %0 and %2 do not need to be changed since
+// BufferPlacement does not move allocs out of nested regions at the moment.
+// However, since %2 is allocated and yielded in a divergent branch, we have
+// to allocate several temporary buffers (like in condBranchDynamicTypeNested).
+
+// CHECK-LABEL: func @inner_region_control_flow_div
+func @inner_region_control_flow_div(
+  %arg0 : index,
+  %arg1 : index) -> memref<?x?xf32> {
+  %0 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %1 = test.region_if %0 : memref<?x?xf32> -> (memref<?x?xf32>) then {
+    ^bb0(%arg2 : memref<?x?xf32>):
+      test.region_if_yield %arg2 : memref<?x?xf32>
+  } else {
+    ^bb0(%arg2 : memref<?x?xf32>):
+      %2 = alloc(%arg0, %arg1) : memref<?x?xf32>
+      test.region_if_yield %2 : memref<?x?xf32>
+  } join {
+    ^bb0(%arg2 : memref<?x?xf32>):
+      test.region_if_yield %arg2 : memref<?x?xf32>
+  }
+  return %1 : memref<?x?xf32>
+}
+
+// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
+// CHECK-NEXT: %[[ALLOC1:.*]] = test.region_if
+// CHECK-NEXT: ^bb0(%[[ALLOC2:.*]]:{{.*}}):
+// CHECK: %[[ALLOC3:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC2]], %[[ALLOC3]])
+// CHECK-NEXT: test.region_if_yield %[[ALLOC3]]
+// CHECK: ^bb0(%[[ALLOC4:.*]]:{{.*}}):
+// CHECK: %[[ALLOC5:.*]] = alloc
+// CHECK: %[[ALLOC6:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC5]], %[[ALLOC6]])
+// CHECK-NEXT: dealloc %[[ALLOC5]]
+// CHECK-NEXT: test.region_if_yield %[[ALLOC6]]
+// CHECK: ^bb0(%[[ALLOC7:.*]]:{{.*}}):
+// CHECK: %[[ALLOC8:.*]] = alloc
+// CHECK-NEXT: linalg.copy(%[[ALLOC7]], %[[ALLOC8]])
+// CHECK-NEXT: dealloc %[[ALLOC7]]
+// CHECK-NEXT: test.region_if_yield %[[ALLOC8]]
+// CHECK: dealloc %[[ALLOC0]]
+// CHECK-NEXT: return %[[ALLOC1]]
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -518,6 +518,77 @@
         setNameFn(getResult(i), str.getValue());
 }
 
+//===----------------------------------------------------------------------===//
+// RegionIfOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, RegionIfOp op) {
+  p << RegionIfOp::getOperationName() << " ";
+  p.printOperands(op.getOperands());
+  p << " : " << op.getOperandTypes();
+  p.printArrowTypeList(op.getResultTypes());
+  p << " then";
+  p.printRegion(op.thenRegion(),
+                /*printEntryBlockArgs=*/true,
+                /*printBlockTerminators=*/true);
+  p << " else";
+  p.printRegion(op.elseRegion(),
+                /*printEntryBlockArgs=*/true,
+                /*printBlockTerminators=*/true);
+  p << " join";
+  p.printRegion(op.joinRegion(),
+                /*printEntryBlockArgs=*/true,
+                /*printBlockTerminators=*/true);
+}
+
+static ParseResult parseRegionIfOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 2> operandInfos;
+  SmallVector<Type, 2> operandTypes;
+
+  result.regions.reserve(3);
+  Region *thenRegion = result.addRegion();
+  Region *elseRegion = result.addRegion();
+  Region *joinRegion = result.addRegion();
+
+  // Parse operand, type and arrow type lists.
+  if (parser.parseOperandList(operandInfos) ||
+      parser.parseColonTypeList(operandTypes) ||
+      parser.parseArrowTypeList(result.types))
+    return failure();
+
+  // Parse all attached regions.
+  if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
+      parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
+      parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
+    return failure();
+
+  return parser.resolveOperands(operandInfos, operandTypes,
+                                parser.getCurrentLocation(), result.operands);
+}
+
+OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
+  assert(index < 2 && "invalid region index");
+  return getOperands();
+}
+
+void RegionIfOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  // Then/else branch to the join region; the join returns to the parent op.
+  if (index.hasValue()) {
+    if (index.getValue() < 2)
+      regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
+    else
+      regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  // The then and else regions are the entry regions of this op.
+  regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
+  regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
+}
+
 //===----------------------------------------------------------------------===//
 // Dialect Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1343,4 +1343,47 @@
   let results = (outs AnyType:$result);
 }
 
+//===----------------------------------------------------------------------===//
+// Test RegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+def RegionIfYieldOp : TEST_Op<"region_if_yield",
+      [NoSideEffect, ReturnLike, Terminator]> {
+  let arguments = (ins Variadic<AnyType>:$results);
+  let assemblyFormat = [{
+    $results `:` type($results) attr-dict
+  }];
+}
+
+def RegionIfOp : TEST_Op<"region_if",
+      [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+       SingleBlockImplicitTerminator<"RegionIfYieldOp">,
+       RecursiveSideEffects]> {
+  let description = [{
+    Represents an abstract if-then-else-join pattern. In this context, the then
+    and else regions jump to the join region, which finally returns to its
+    parent op.
+  }];
+
+  let printer = [{ return ::print(p, *this); }];
+  let parser = [{ return ::parseRegionIfOp(parser, result); }];
+  let arguments = (ins Variadic<AnyType>);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$thenRegion,
+                        SizedRegion<1>:$elseRegion,
+                        SizedRegion<1>:$joinRegion);
+  let extraClassDeclaration = [{
+    Block::BlockArgListType getThenArgs() {
+      return getBody(0)->getArguments();
+    }
+    Block::BlockArgListType getElseArgs() {
+      return getBody(1)->getArguments();
+    }
+    Block::BlockArgListType getJoinArgs() {
+      return getBody(2)->getArguments();
+    }
+    OperandRange getSuccessorEntryOperands(unsigned index);
+  }];
+}
+
 #endif // TEST_OPS