diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h @@ -353,8 +353,8 @@ /// any effect on the lattice that isn't already expressed by the interface /// itself. virtual void visitRegionBranchControlFlowTransfer( - RegionBranchOpInterface branch, std::optional regionFrom, - std::optional regionTo, const AbstractDenseLattice &after, + RegionBranchOpInterface branch, RegionBranchPoint regionFrom, + RegionBranchPoint regionTo, const AbstractDenseLattice &after, AbstractDenseLattice *before) { meet(before, after); } @@ -382,7 +382,7 @@ /// of the branch operation itself. void visitRegionBranchOperation(ProgramPoint point, RegionBranchOpInterface branch, - std::optional regionNo, + RegionBranchPoint branchPoint, AbstractDenseLattice *before); /// Visit an operation for which the data flow is described by the @@ -472,9 +472,8 @@ /// nullptr`. The behavior can be further refined for specific pairs of "from" /// and "to" regions. 
virtual void visitRegionBranchControlFlowTransfer( - RegionBranchOpInterface branch, std::optional regionFrom, - std::optional regionTo, const LatticeT &after, - LatticeT *before) { + RegionBranchOpInterface branch, RegionBranchPoint regionFrom, + RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) { AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer( branch, regionFrom, regionTo, after, before); } @@ -508,8 +507,8 @@ static_cast(before)); } void visitRegionBranchControlFlowTransfer( - RegionBranchOpInterface branch, std::optional regionForm, - std::optional regionTo, const AbstractDenseLattice &after, + RegionBranchOpInterface branch, RegionBranchPoint regionForm, + RegionBranchPoint regionTo, const AbstractDenseLattice &after, AbstractDenseLattice *before) final { visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo, static_cast(after), diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -243,7 +243,7 @@ /// regions or the parent operation itself, and set either the argument or /// parent result lattices. void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch, - std::optional successorIndex, + RegionBranchPoint successor, ArrayRef lattices); }; diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -190,6 +190,68 @@ ValueRange inputs; }; +/// This class represents a point being branched from in the methods of the +/// `RegionBranchOpInterface`. 
+/// One can branch from one of two kinds of places: +/// * The parent operation (aka the `RegionBranchOpInterface` implementation) +/// * A region within the parent operation. +class RegionBranchPoint { +public: + /// Returns an instance of `RegionBranchPoint` representing the parent + /// operation. + static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); } + + /// Creates a `RegionBranchPoint` that branches from the given region. + /// The pointer must not be null. + RegionBranchPoint(Region *region) : maybeRegion(region) { + assert(region && "Region must not be null"); + } + + RegionBranchPoint(Region ®ion) : RegionBranchPoint(®ion) {} + + /// Explicitly stops users from constructing with `nullptr`. + RegionBranchPoint(std::nullptr_t) = delete; + + /// Constructs a `RegionBranchPoint` from the target of a + /// `RegionSuccessor` instance. + RegionBranchPoint(RegionSuccessor successor) { + if (successor.isParent()) + maybeRegion = nullptr; + else + maybeRegion = successor.getSuccessor(); + } + + /// Assigns a region being branched from. + RegionBranchPoint &operator=(Region ®ion) { + maybeRegion = ®ion; + return *this; + } + + /// Returns true if branching from the parent op. + bool isParent() const { return maybeRegion == nullptr; } + + /// Returns the region if branching from a region. + /// A null pointer otherwise. + Region *getRegionOrNull() const { return maybeRegion; } + + /// Returns true if the two branch points are equal. + friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) { + return lhs.maybeRegion == rhs.maybeRegion; + } + +private: + // Private constructor to encourage the use of `RegionBranchPoint::parent`. + constexpr RegionBranchPoint() : maybeRegion(nullptr) {} + + /// Internal encoding. Uses nullptr for representing branching from the parent + /// op and the region being branched from otherwise.
+ Region *maybeRegion; +}; + +inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) { + return !(lhs == rhs); +} + /// This class represents upper and lower bounds on the number of times a region /// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least /// zero, but the upper bound may not be known. diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -133,14 +133,14 @@ let methods = [ InterfaceMethod<[{ Returns the operands of this operation used as the entry arguments when - entering the region at `index`, which was specified as a successor of + branching to `point`, which was specified as a successor of this operation by `getEntrySuccessorRegions`, or the operands forwarded to the operation's results when it branches back to itself. These operands should correspond 1-1 with the successor inputs specified in `getEntrySuccessorRegions`. }], "::mlir::OperandRange", "getEntrySuccessorOperands", - (ins "::std::optional":$index), [{}], + (ins "::mlir::RegionBranchPoint":$point), [{}], /*defaultImplementation=*/[{ auto operandEnd = this->getOperation()->operand_end(); return ::mlir::OperandRange(operandEnd, operandEnd); @@ -162,22 +162,20 @@ (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands, "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}], [{ - $_op.getSuccessorRegions(std::nullopt, regions); + $_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions); }] >, InterfaceMethod<[{ - Returns the viable successors of a region at `index`, or the possible - successors when branching from the parent op if `index` is None. These - are the regions that may be selected during the flow of control. The - parent operation, i.e.
a null `index`, may specify itself as successor, - which indicates that the control flow may not enter any region at all. - This method allows for describing which regions may be executed when - entering an operation, and which regions are executed after having - executed another region of the parent op. The successor region must be - non-empty. + Returns the viable successors of `point`. These are the regions that may + be selected during the flow of control. The parent operation may + specify itself as successor, which indicates that the control flow may + not enter any region at all. This method allows for describing which + regions may be executed when entering an operation, and which regions + are executed after having executed another region of the parent op. The + successor region must be non-empty. }], "void", "getSuccessorRegions", - (ins "::std::optional":$index, + (ins "::mlir::RegionBranchPoint":$point, "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions) >, InterfaceMethod<[{ @@ -245,12 +243,10 @@ let methods = [ InterfaceMethod<[{ Returns a mutable range of operands that are semantically "returned" by - passing them to the region successor given by `index`. If `index` is - None, this function returns the operands that are passed as a result to - the parent operation. + passing them to the region successor given by `point`. }], "::mlir::MutableOperandRange", "getMutableSuccessorOperands", - (ins "::std::optional":$index) + (ins "::mlir::RegionBranchPoint":$point) >, InterfaceMethod<[{ Returns the viable region successors that are branched to after this @@ -269,8 +265,7 @@ [{ ::mlir::Operation *op = $_op; ::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp()) - .getSuccessorRegions(op->getParentRegion()->getRegionNumber(), - regions); + .getSuccessorRegions(op->getParentRegion(), regions); }] >, ]; @@ -290,8 +285,8 @@ // them to the region successor given by `index`.
If `index` is None, this // function returns the operands that are passed as a result to the parent // operation. - ::mlir::OperandRange getSuccessorOperands(std::optional index) { - return getMutableSuccessorOperands(index); + ::mlir::OperandRange getSuccessorOperands(::mlir::RegionBranchPoint point) { + return getMutableSuccessorOperands(point); } }]; } @@ -309,7 +304,7 @@ /*extraOpDeclaration=*/"", /*extraOpDefinition=*/[{ ::mlir::MutableOperandRange $cppClass::getMutableSuccessorOperands( - ::std::optional index) { + ::mlir::RegionBranchPoint point) { return ::mlir::MutableOperandRange(*this); } }] diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -45,9 +45,9 @@ // this region predecessor that correspond to the input values of `region`. If // an index could not be found, std::nullopt is returned instead. auto getOperandIndexIfPred = - [&](std::optional predIndex) -> std::optional { + [&](RegionBranchPoint pred) -> std::optional { SmallVector successors; - branch.getSuccessorRegions(predIndex, successors); + branch.getSuccessorRegions(pred, successors); for (RegionSuccessor &successor : successors) { if (successor.getSuccessor() != region) continue; @@ -75,28 +75,27 @@ }; // Check branches from the parent operation. - std::optional regionIndex; - if (region) { - // Determine the actual region number from the passed region. 
- regionIndex = region->getRegionNumber(); - } + auto branchPoint = RegionBranchPoint::parent(); + if (region) + branchPoint = region; + if (std::optional operandIndex = - getOperandIndexIfPred(/*predIndex=*/std::nullopt)) { + getOperandIndexIfPred(/*pred=*/RegionBranchPoint::parent())) { collectUnderlyingAddressValues( - branch.getEntrySuccessorOperands(regionIndex)[*operandIndex], maxDepth, + branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth, visited, output); } // Check branches from each child region. Operation *op = branch.getOperation(); - for (int i = 0, e = op->getNumRegions(); i != e; ++i) { - if (std::optional operandIndex = getOperandIndexIfPred(i)) { - for (Block &block : op->getRegion(i)) { + for (Region ®ion : op->getRegions()) { + if (std::optional operandIndex = getOperandIndexIfPred(region)) { + for (Block &block : region) { // Try to determine possible region-branch successor operands for the // current region. if (auto term = dyn_cast( block.getTerminator())) { collectUnderlyingAddressValues( - term.getSuccessorOperands(regionIndex)[*operandIndex], maxDepth, + term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth, visited, output); } else if (block.getNumSuccessors()) { // Otherwise, if this terminator may exit the region we can't make diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -312,7 +312,8 @@ // Special cases where control flow may dictate data flow.
if (auto branch = dyn_cast(op)) - return visitRegionBranchOperation(op, branch, std::nullopt, before); + return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(), + before); if (auto call = dyn_cast(op)) return visitCallOperation(call, before); @@ -368,8 +369,7 @@ // If this block is exiting from an operation with region-based control // flow, propagate the lattice back along the control flow edge. if (auto branch = dyn_cast(block->getParentOp())) { - visitRegionBranchOperation(block, branch, - block->getParent()->getRegionNumber(), before); + visitRegionBranchOperation(block, branch, block->getParent(), before); return; } @@ -396,13 +396,13 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation( ProgramPoint point, RegionBranchOpInterface branch, - std::optional regionNo, AbstractDenseLattice *before) { + RegionBranchPoint branchPoint, AbstractDenseLattice *before) { // The successors of the operation may be either the first operation of the // entry block of each possible successor region, or the next operation when // the branch is a successor of itself. SmallVector successors; - branch.getSuccessorRegions(regionNo, successors); + branch.getSuccessorRegions(branchPoint, successors); for (const RegionSuccessor &successor : successors) { const AbstractDenseLattice *after; if (successor.isParent() || successor.getSuccessor()->empty()) { @@ -423,10 +423,8 @@ else after = getLatticeFor(point, &successorBlock->front()); } - std::optional successorNo = - successor.isParent() ? 
std::optional() - : successor.getSuccessor()->getRegionNumber(); - visitRegionBranchControlFlowTransfer(branch, regionNo, successorNo, *after, + + visitRegionBranchControlFlowTransfer(branch, branchPoint, successor, *after, before); } } diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -99,7 +99,7 @@ // The results of a region branch operation are determined by control-flow. if (auto branch = dyn_cast(op)) { return visitRegionSuccessors({branch}, branch, - /*successorIndex=*/std::nullopt, + /*successor=*/RegionBranchPoint::parent(), resultLattices); } @@ -167,8 +167,8 @@ // Check if the lattices can be determined from region control flow. if (auto branch = dyn_cast(block->getParentOp())) { - return visitRegionSuccessors( - block, branch, block->getParent()->getRegionNumber(), argLattices); + return visitRegionSuccessors(block, branch, block->getParent(), + argLattices); } // Otherwise, we can't reason about the data-flow. @@ -212,8 +212,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( ProgramPoint point, RegionBranchOpInterface branch, - std::optional successorIndex, - ArrayRef lattices) { + RegionBranchPoint successor, ArrayRef lattices) { const auto *predecessors = getOrCreateFor(point, point); assert(predecessors->allPredecessorsKnown() && "unexpected unresolved region successors"); @@ -224,11 +223,11 @@ // Check if the predecessor is the parent op. if (op == branch) { - operands = branch.getEntrySuccessorOperands(successorIndex); + operands = branch.getEntrySuccessorOperands(successor); // Otherwise, try to deduce the operands from a region return-like op. 
} else if (auto regionTerminator = dyn_cast(op)) { - operands = regionTerminator.getSuccessorOperands(successorIndex); + operands = regionTerminator.getSuccessorOperands(successor); } if (!operands) { @@ -501,10 +500,7 @@ BitVector unaccounted(op->getNumOperands(), true); for (RegionSuccessor &successor : successors) { - Region *region = successor.getSuccessor(); - OperandRange operands = - region ? branch.getEntrySuccessorOperands(region->getRegionNumber()) - : branch.getEntrySuccessorOperands({}); + OperandRange operands = branch.getEntrySuccessorOperands(successor); MutableArrayRef opoperands = operandsToOpOperands(operands); ValueRange inputs = successor.getSuccessorInputs(); for (auto [operand, input] : llvm::zip(opoperands, inputs)) { @@ -538,9 +534,7 @@ for (const RegionSuccessor &successor : successors) { ValueRange inputs = successor.getSuccessorInputs(); - Region *region = successor.getSuccessor(); - OperandRange operands = terminator.getSuccessorOperands( - region ? region->getRegionNumber() : std::optional{}); + OperandRange operands = terminator.getSuccessorOperands(successor); MutableArrayRef opOperands = operandsToOpOperands(operands); for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) { meet(getLatticeElement(opOperand.get()), diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2379,9 +2379,9 @@ /// correspond to the loop iterator operands, i.e., those excluding the /// induction variable. AffineForOp only has one region, so zero is the only /// valid value for `index`. 
-OperandRange -AffineForOp::getEntrySuccessorOperands(std::optional index) { - assert((!index || *index == 0) && "invalid region index"); +OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { + assert((point.isParent() || point == getLoopBody()) && + "invalid region point"); // The initial operands map to the loop arguments after the induction // variable or are forwarded to the results when the trip count is zero. @@ -2394,14 +2394,15 @@ /// correspond to a constant value for each operand, or null if that operand is /// not a constant. void AffineForOp::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { - assert((!index.has_value() || index.value() == 0) && "expected loop region"); + RegionBranchPoint point, SmallVectorImpl ®ions) { + assert((point.isParent() || point == getLoopBody()) && + "expected loop region"); // The loop may typically branch back to its body or to the parent operation. // If the predecessor is the parent op and the trip count is known to be at // least one, branch into the body using the iterator arguments. And in cases // we know the trip count is zero, it can only branch back to its parent. std::optional tripCount = getTrivialConstantTripCount(*this); - if (!index.has_value() && tripCount.has_value()) { + if (point.isParent() && tripCount.has_value()) { if (tripCount.value() > 0) { regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs())); return; @@ -2414,7 +2415,7 @@ // From the loop body, if the trip count is one, we can only branch back to // the parent. - if (index && tripCount && *tripCount == 1) { + if (!point.isParent() && tripCount && *tripCount == 1) { regions.push_back(RegionSuccessor(getResults())); return; } @@ -2859,10 +2860,10 @@ /// AffineIfOp has two regions -- `then` and `else`. 
The flow of data should be /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp void AffineIfOp::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { + RegionBranchPoint point, SmallVectorImpl ®ions) { // If the predecessor is an AffineIfOp, then branching into both `then` and // `else` region is valid. - if (!index.has_value()) { + if (point.isParent()) { regions.reserve(2); regions.push_back( RegionSuccessor(&getThenRegion(), getThenRegion().getArguments())); diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -38,9 +38,8 @@ constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes"; -OperandRange -ExecuteOp::getEntrySuccessorOperands(std::optional index) { - assert(index && *index == 0 && "invalid region index"); +OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) { + assert(point == getBodyRegion() && "invalid region index"); return getBodyOperands(); } @@ -53,11 +52,10 @@ return getValueOrTokenType(lhs) == getValueOrTokenType(rhs); } -void ExecuteOp::getSuccessorRegions(std::optional index, +void ExecuteOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // The `body` region branch back to the parent operation. - if (index) { - assert(*index == 0 && "invalid region index"); + if (point == getBodyRegion()) { regions.push_back(RegionSuccessor(getBodyResults())); return; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp @@ -372,7 +372,7 @@ // parent operation. In this case, we have to introduce an additional clone // for buffer that is passed to the argument. 
SmallVector successorRegions; - regionInterface.getSuccessorRegions(/*index=*/std::nullopt, + regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(), successorRegions); auto *it = llvm::find_if(successorRegions, [&](RegionSuccessor &successorRegion) { @@ -383,8 +383,7 @@ // Determine the actual operand to introduce a clone for and rewire the // operand to point to the clone instead. - auto operands = - regionInterface.getEntrySuccessorOperands(argRegion->getRegionNumber()); + auto operands = regionInterface.getEntrySuccessorOperands(argRegion); size_t operandIndex = llvm::find(it->getSuccessorInputs(), blockArg).getIndex() + operands.getBeginOperandIndex(); @@ -432,8 +431,7 @@ // Query the regionInterface to get all successor regions of the current // one. SmallVector successorRegions; - regionInterface.getSuccessorRegions(region.getRegionNumber(), - successorRegions); + regionInterface.getSuccessorRegions(region, successorRegions); // Try to find a matching region successor. RegionSuccessor *regionSuccessor = llvm::find_if(successorRegions, regionPredicate); @@ -445,10 +443,6 @@ llvm::find(regionSuccessor->getSuccessorInputs(), argValue) .getIndex(); - std::optional successorRegionNumber; - if (Region *successorRegion = regionSuccessor->getSuccessor()) - successorRegionNumber = successorRegion->getRegionNumber(); - // Iterate over all immediate terminator operations to introduce // new buffer allocations. Thereby, the appropriate terminator operand // will be adjusted to point to the newly allocated buffer instead. @@ -456,8 +450,7 @@ ®ion, [&](RegionBranchTerminatorOpInterface terminator) { // Get the actual mutable operands for this terminator op. auto terminatorOperands = - terminator.getMutableSuccessorOperands( - successorRegionNumber); + terminator.getMutableSuccessorOperands(*regionSuccessor); // Extract the source value from the current terminator. 
// This conversion needs to exist on a separate line due to a // bug in GCC conversion analysis. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp @@ -123,7 +123,7 @@ return true; // Recurses into all region successors. SmallVector successors; - regionInterface.getSuccessorRegions(current->getRegionNumber(), successors); + regionInterface.getSuccessorRegions(current, successors); for (RegionSuccessor ®ionEntry : successors) if (recurse(regionEntry.getSuccessor())) return true; @@ -132,7 +132,8 @@ // Start with all entry regions and test whether they induce a loop. SmallVector successorRegions; - regionInterface.getSuccessorRegions(/*index=*/std::nullopt, successorRegions); + regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(), + successorRegions); for (RegionSuccessor ®ionEntry : successorRegions) { if (recurse(regionEntry.getSuccessor())) return true; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp @@ -100,16 +100,13 @@ // Query the RegionBranchOpInterface to find potential successor regions. // Extract all entry regions and wire all initial entry successor inputs. SmallVector entrySuccessors; - regionInterface.getSuccessorRegions(/*index=*/std::nullopt, + regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(), entrySuccessors); for (RegionSuccessor &entrySuccessor : entrySuccessors) { // Wire the entry region's successor arguments with the initial // successor inputs. registerDependencies( - regionInterface.getEntrySuccessorOperands( - entrySuccessor.isParent() - ? 
std::optional() - : entrySuccessor.getSuccessor()->getRegionNumber()), + regionInterface.getEntrySuccessorOperands(entrySuccessor), entrySuccessor.getSuccessorInputs()); } @@ -118,21 +115,16 @@ // Iterate over all successor region entries that are reachable from the // current region. SmallVector successorRegions; - regionInterface.getSuccessorRegions(region.getRegionNumber(), - successorRegions); + regionInterface.getSuccessorRegions(region, successorRegions); for (RegionSuccessor &successorRegion : successorRegions) { - // Determine the current region index (if any). - std::optional regionIndex; - Region *regionSuccessor = successorRegion.getSuccessor(); - if (regionSuccessor) - regionIndex = regionSuccessor->getRegionNumber(); // Iterate over all immediate terminator operations and wire the // successor inputs with the successor operands of each terminator. for (Block &block : region) if (auto terminator = dyn_cast( block.getTerminator())) - registerDependencies(terminator.getSuccessorOperands(regionIndex), - successorRegion.getSuccessorInputs()); + registerDependencies( + terminator.getSuccessorOperands(successorRegion), + successorRegion.getSuccessorInputs()); } } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -455,8 +455,8 @@ } void AllocaScopeOp::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { - if (index) { + RegionBranchPoint point, SmallVectorImpl ®ions) { + if (!point.isParent()) { regions.push_back(RegionSuccessor(getResults())); return; } diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -266,9 +266,9 @@ /// correspond to a constant value for each operand, or null if that operand is /// not a constant. 
void ExecuteRegionOp::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { + RegionBranchPoint point, SmallVectorImpl ®ions) { // If the predecessor is the ExecuteRegionOp, branch into the body. - if (!index) { + if (point.isParent()) { regions.push_back(RegionSuccessor(&getRegion())); return; } @@ -282,8 +282,8 @@ //===----------------------------------------------------------------------===// MutableOperandRange -ConditionOp::getMutableSuccessorOperands(std::optional index) { - assert((!index || index == getParentOp().getAfter().getRegionNumber()) && +ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) { + assert((point.isParent() || point == getParentOp().getAfter()) && "condition op can only exit the loop or branch to the after" "region"); // Pass all operands except the condition to the successor region. @@ -553,7 +553,7 @@ /// Return operands used when entering the region at 'index'. These operands /// correspond to the loop iterator operands, i.e., those excluding the /// induction variable. -OperandRange ForOp::getEntrySuccessorOperands(std::optional index) { +OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) { return getInitArgs(); } @@ -562,7 +562,7 @@ /// during the flow of control. `operands` is a set of optional attributes that /// correspond to a constant value for each operand, or null if that operand is /// not a constant. -void ForOp::getSuccessorRegions(std::optional index, +void ForOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // Both the operation itself and the region may be branching into the body or // back into the operation itself. It is possible for loop not to enter the @@ -1731,7 +1731,7 @@ /// during the flow of control. `operands` is a set of optional attributes that /// correspond to a constant value for each operand, or null if that operand is /// not a constant. 
-void ForallOp::getSuccessorRegions(std::optional index, +void ForallOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // Both the operation itself and the region may be branching into the body or // back into the operation itself. It is possible for loop not to enter the @@ -2011,10 +2011,10 @@ /// during the flow of control. `operands` is a set of optional attributes that /// correspond to a constant value for each operand, or null if that operand is /// not a constant. -void IfOp::getSuccessorRegions(std::optional index, +void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // The `then` and the `else` region branch back to the parent operation. - if (index) { + if (!point.isParent()) { regions.push_back(RegionSuccessor(getResults())); return; } @@ -3042,7 +3042,7 @@ /// correspond to a constant value for each operand, or null if that operand is /// not a constant. void ParallelOp::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { + RegionBranchPoint point, SmallVectorImpl ®ions) { // Both the operation itself and the region may be branching into the body or // back into the operation itself. It is possible for loop not to enter the // body. @@ -3169,8 +3169,8 @@ afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments()); } -OperandRange WhileOp::getEntrySuccessorOperands(std::optional index) { - assert(index && *index == 0 && +OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) { + assert(point == getBefore() && "WhileOp is expected to branch only to the first region"); return getInits(); @@ -3192,17 +3192,18 @@ return getAfterBody()->getArguments(); } -void WhileOp::getSuccessorRegions(std::optional index, +void WhileOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // The parent op always branches to the condition region. 
- if (!index) { + if (point.isParent()) { regions.emplace_back(&getBefore(), getBefore().getArguments()); return; } - assert(*index < 2 && "there are only two regions in a WhileOp"); + assert(llvm::is_contained({&getAfter(), &getBefore()}, point) && + "there are only two regions in a WhileOp"); // The body region always branches back to the condition region. - if (*index == 1) { + if (point == getAfter()) { regions.emplace_back(&getBefore(), getBefore().getArguments()); return; } @@ -4023,10 +4024,9 @@ } void IndexSwitchOp::getSuccessorRegions( - std::optional index, - SmallVectorImpl &successors) { + RegionBranchPoint point, SmallVectorImpl &successors) { // All regions branch back to the parent op. - if (index) { + if (!point.isParent()) { successors.emplace_back(getResults()); return; } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -335,11 +335,11 @@ // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td void AssumingOp::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { + RegionBranchPoint point, SmallVectorImpl ®ions) { // AssumingOp has unconditional control flow into the region and back to the // parent, so return the correct RegionSuccessor purely based on the index // being None or 0. 
- if (index) { + if (!point.isParent()) { regions.push_back(RegionSuccessor(getResults())); return; } diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -86,23 +86,25 @@ // AlternativesOp //===----------------------------------------------------------------------===// -OperandRange transform::AlternativesOp::getEntrySuccessorOperands( - std::optional index) { - if (index && getOperation()->getNumOperands() == 1) +OperandRange +transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { + if (!point.isParent() && getOperation()->getNumOperands() == 1) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), getOperation()->operand_end()); } void transform::AlternativesOp::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { + RegionBranchPoint point, SmallVectorImpl ®ions) { for (Region &alternative : llvm::drop_begin( - getAlternatives(), index.has_value() ? *index + 1 : 0)) { + getAlternatives(), + point.isParent() ? 0 + : point.getRegionOrNull()->getRegionNumber() + 1)) { regions.emplace_back(&alternative, !getOperands().empty() ? alternative.getArguments() : Block::BlockArgListType()); } - if (index.has_value()) + if (!point.isParent()) regions.emplace_back(getOperation()->getResults()); } @@ -1159,24 +1161,24 @@ } void transform::ForeachOp::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { + RegionBranchPoint point, SmallVectorImpl ®ions) { Region *bodyRegion = &getBody(); - if (!index) { + if (point.isParent()) { regions.emplace_back(bodyRegion, bodyRegion->getArguments()); return; } // Branch back to the region or the parent. 
- assert(*index == 0 && "unexpected region index"); + assert(point == getBody() && "unexpected region index"); regions.emplace_back(bodyRegion, bodyRegion->getArguments()); regions.emplace_back(); } OperandRange -transform::ForeachOp::getEntrySuccessorOperands(std::optional index) { +transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) { // The iteration variable op handle is mapped to a subset (one op to be // precise) of the payload ops of the ForeachOp operand. - assert(index && *index == 0 && "unexpected region index"); + assert(point == getBody() && "unexpected region index"); return getOperation()->getOperands(); } @@ -2178,9 +2180,9 @@ getPotentialTopLevelEffects(effects); } -OperandRange transform::SequenceOp::getEntrySuccessorOperands( - std::optional index) { - assert(index && *index == 0 && "unexpected region index"); +OperandRange +transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) { + assert(point == getBody() && "unexpected region index"); if (getOperation()->getNumOperands() > 0) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), @@ -2188,8 +2190,8 @@ } void transform::SequenceOp::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { - if (!index) { + RegionBranchPoint point, SmallVectorImpl ®ions) { + if (point.isParent()) { Region *bodyRegion = &getBody(); regions.emplace_back(bodyRegion, getNumOperands() != 0 ? 
bodyRegion->getArguments() @@ -2197,7 +2199,7 @@ return; } - assert(*index == 0 && "unexpected region index"); + assert(point == getBody() && "unexpected region index"); regions.emplace_back(getOperation()->getResults()); } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5821,8 +5821,8 @@ } void WarpExecuteOnLane0Op::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { - if (index) { + RegionBranchPoint point, SmallVectorImpl ®ions) { + if (!point.isParent()) { regions.push_back(RegionSuccessor(getResults())); return; } diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -84,18 +84,18 @@ // RegionBranchOpInterface //===----------------------------------------------------------------------===// -static InFlightDiagnostic & -printRegionEdgeName(InFlightDiagnostic &diag, std::optional sourceNo, - std::optional succRegionNo) { +static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag, + RegionBranchPoint sourceNo, + RegionBranchPoint succRegionNo) { diag << "from "; - if (sourceNo) - diag << "Region #" << sourceNo.value(); + if (Region *region = sourceNo.getRegionOrNull()) + diag << "Region #" << region->getRegionNumber(); else diag << "parent operands"; diag << " to "; - if (succRegionNo) - diag << "Region #" << succRegionNo.value(); + if (Region *region = succRegionNo.getRegionOrNull()) + diag << "Region #" << region->getRegionNumber(); else diag << "parent results"; return diag; @@ -107,28 +107,24 @@ /// inputs that flow from `sourceIndex' to the given region, or std::nullopt if /// the exact type match verification is not necessary (e.g., if the Op verifies /// the match itself). 
-static LogicalResult verifyTypesAlongAllEdges( - Operation *op, std::optional sourceNo, - function_ref(std::optional)> - getInputsTypesForRegion) { +static LogicalResult +verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, + function_ref(RegionBranchPoint)> + getInputsTypesForRegion) { auto regionInterface = cast(op); SmallVector successors; - regionInterface.getSuccessorRegions(sourceNo, successors); + regionInterface.getSuccessorRegions(sourcePoint, successors); for (RegionSuccessor &succ : successors) { - std::optional succRegionNo; - if (!succ.isParent()) - succRegionNo = succ.getSuccessor()->getRegionNumber(); - - FailureOr sourceTypes = getInputsTypesForRegion(succRegionNo); + FailureOr sourceTypes = getInputsTypesForRegion(succ); if (failed(sourceTypes)) return failure(); TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); if (sourceTypes->size() != succInputsTypes.size()) { InFlightDiagnostic diag = op->emitOpError(" region control flow edge "); - return printRegionEdgeName(diag, sourceNo, succRegionNo) + return printRegionEdgeName(diag, sourcePoint, succ) << ": source has " << sourceTypes->size() << " operands, but target successor needs " << succInputsTypes.size(); @@ -140,7 +136,7 @@ Type inputType = std::get<1>(typesIdx.value()); if (!regionInterface.areTypesCompatible(sourceType, inputType)) { InFlightDiagnostic diag = op->emitOpError(" along control flow edge "); - return printRegionEdgeName(diag, sourceNo, succRegionNo) + return printRegionEdgeName(diag, sourcePoint, succ) << ": source type #" << typesIdx.index() << " " << sourceType << " should match input type #" << typesIdx.index() << " " << inputType; @@ -154,13 +150,13 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { auto regionInterface = cast(op); - auto inputTypesFromParent = - [&](std::optional regionNo) -> TypeRange { + auto inputTypesFromParent = [&](RegionBranchPoint regionNo) -> TypeRange { return 
regionInterface.getEntrySuccessorOperands(regionNo).getTypes(); }; // Verify types along control flow edges originating from the parent. - if (failed(verifyTypesAlongAllEdges(op, std::nullopt, inputTypesFromParent))) + if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(), + inputTypesFromParent))) return failure(); auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) { @@ -176,8 +172,7 @@ }; // Verify types along control flow edges originating from each region. - for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) { - Region ®ion = op->getRegion(regionNo); + for (Region ®ion : op->getRegions()) { // Since there can be multiple terminators implementing the // `RegionBranchTerminatorOpInterface`, all should have the same operand @@ -195,7 +190,7 @@ continue; auto inputTypesForRegion = - [&](std::optional succRegionNo) -> FailureOr { + [&](RegionBranchPoint succRegionNo) -> FailureOr { std::optional regionReturnOperands; for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) { auto terminatorOperands = @@ -211,7 +206,7 @@ if (!areTypesCompatible(regionReturnOperands->getTypes(), terminatorOperands.getTypes())) { InFlightDiagnostic diag = op->emitOpError(" along control flow edge"); - return printRegionEdgeName(diag, regionNo, succRegionNo) + return printRegionEdgeName(diag, region, succRegionNo) << " operands mismatch between return-like terminators"; } } @@ -220,7 +215,7 @@ return TypeRange(regionReturnOperands->getTypes()); }; - if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesForRegion))) + if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion))) return failure(); } @@ -237,24 +232,24 @@ visited[begin->getRegionNumber()] = true; // Retrieve all successors of the region and enqueue them in the worklist. 
- SmallVector worklist; - auto enqueueAllSuccessors = [&](unsigned index) { + SmallVector worklist; + auto enqueueAllSuccessors = [&](Region *region) { SmallVector successors; - op.getSuccessorRegions(index, successors); + op.getSuccessorRegions(region, successors); for (RegionSuccessor successor : successors) if (!successor.isParent()) - worklist.push_back(successor.getSuccessor()->getRegionNumber()); + worklist.push_back(successor.getSuccessor()); }; - enqueueAllSuccessors(begin->getRegionNumber()); + enqueueAllSuccessors(begin); // Process all regions in the worklist via DFS. while (!worklist.empty()) { - unsigned nextRegion = worklist.pop_back_val(); - if (nextRegion == r->getRegionNumber()) + Region *nextRegion = worklist.pop_back_val(); + if (nextRegion == r) return true; - if (visited[nextRegion]) + if (visited[nextRegion->getRegionNumber()]) continue; - visited[nextRegion] = true; + visited[nextRegion->getRegionNumber()] = true; enqueueAllSuccessors(nextRegion); } diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -316,15 +316,11 @@ // Return the successors of `region` if the latter is not null. Else return // the successors of `regionBranchOp`. auto getSuccessors = [&](Region *region = nullptr) { - std::optional index = - region ? std::optional(region->getRegionNumber()) : std::nullopt; + auto point = region ? region : RegionBranchPoint::parent(); SmallVector operandAttributes(regionBranchOp->getNumOperands(), nullptr); SmallVector successors; - if (!index) - regionBranchOp.getEntrySuccessorRegions(operandAttributes, successors); - else - regionBranchOp.getSuccessorRegions(index, successors); + regionBranchOp.getSuccessorRegions(point, successors); return successors; }; @@ -333,14 +329,10 @@ // forwarded to `successor`. 
auto getForwardedOpOperands = [&](const RegionSuccessor &successor, Operation *terminator = nullptr) { - Region *successorRegion = successor.getSuccessor(); - std::optional index = - successorRegion ? std::optional(successorRegion->getRegionNumber()) - : std::nullopt; OperandRange operands = terminator ? cast(terminator) - .getSuccessorOperands(index) - : regionBranchOp.getEntrySuccessorOperands(index); + .getSuccessorOperands(successor) + : regionBranchOp.getEntrySuccessorOperands(successor); SmallVector opOperands = operandsToOpOperands(operands); return opOperands; }; diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp --- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp @@ -60,8 +60,8 @@ NextAccess *before) override; void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, - std::optional regionFrom, - std::optional regionTo, + RegionBranchPoint regionFrom, + RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) override; @@ -124,15 +124,15 @@ } void NextAccessAnalysis::visitRegionBranchControlFlowTransfer( - RegionBranchOpInterface branch, std::optional regionFrom, - std::optional regionTo, const NextAccess &after, - NextAccess *before) { + RegionBranchOpInterface branch, RegionBranchPoint regionFrom, + RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) { auto testStoreWithARegion = dyn_cast<::test::TestStoreWithARegion>(branch.getOperation()); if (testStoreWithARegion && - ((!regionTo && !testStoreWithARegion.getStoreBeforeRegion()) || - (!regionFrom && testStoreWithARegion.getStoreBeforeRegion()))) { + ((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) || + (regionFrom.isParent() && + testStoreWithARegion.getStoreBeforeRegion()))) { visitOperation(branch, static_cast(after), 
static_cast(before)); } else { @@ -219,7 +219,7 @@ SmallVector entryPointNextAccess; SmallVector regionSuccessors; - iface.getSuccessorRegions(std::nullopt, regionSuccessors); + iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors); for (const RegionSuccessor &successor : regionSuccessors) { if (!successor.getSuccessor() || successor.getSuccessor()->empty()) continue; diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -931,17 +931,17 @@ parser.getCurrentLocation(), result.operands); } -OperandRange -RegionIfOp::getEntrySuccessorOperands(std::optional index) { - assert(index && *index < 2 && "invalid region index"); +OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) { + assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) && + "invalid region index"); return getOperands(); } void RegionIfOp::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { + RegionBranchPoint point, SmallVectorImpl ®ions) { // We always branch to the join region. - if (index.has_value()) { - if (index.value() < 2) + if (!point.isParent()) { + if (point != getJoinRegion()) regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); else regions.push_back(RegionSuccessor(getResults())); @@ -964,11 +964,11 @@ // AnyCondOp //===----------------------------------------------------------------------===// -void AnyCondOp::getSuccessorRegions(std::optional index, +void AnyCondOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // The parent op branches into the only region, and the region branches back // to the parent op. 
- if (!index) + if (point.isParent()) regions.emplace_back(&getRegion()); else regions.emplace_back(getResults()); @@ -985,17 +985,16 @@ //===----------------------------------------------------------------------===// void LoopBlockOp::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { + RegionBranchPoint point, SmallVectorImpl ®ions) { regions.emplace_back(&getBody(), getBody().getArguments()); - if (!index) + if (point.isParent()) return; regions.emplace_back((*this)->getResults()); } -OperandRange -LoopBlockOp::getEntrySuccessorOperands(std::optional index) { - assert(index == 0); +OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { + assert(point == getBody()); return getInitMutable(); } @@ -1003,10 +1002,9 @@ // LoopBlockTerminatorOp //===----------------------------------------------------------------------===// -MutableOperandRange LoopBlockTerminatorOp::getMutableSuccessorOperands( - std::optional index) { - assert(!index || index == 0); - if (!index) +MutableOperandRange +LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) { + if (point.isParent()) return getExitArgMutable(); return getNextIterArgMutable(); } @@ -1313,12 +1311,11 @@ } void TestStoreWithARegion::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { - if (!index) { + RegionBranchPoint point, SmallVectorImpl ®ions) { + if (point.isParent()) regions.emplace_back(&getBody(), getBody().front().getArguments()); - } else { + else regions.emplace_back(); - } } LogicalResult diff --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp --- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp +++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp @@ -37,7 +37,7 @@ } // Regions have no successors. 
- void getSuccessorRegions(std::optional index, + void getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) {} }; @@ -51,14 +51,13 @@ static StringRef getOperationName() { return "cftest.loop_regions_op"; } - void getSuccessorRegions(std::optional index, + void getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { - if (index) { - if (*index == 1) + if (Region *region = point.getRegionOrNull()) { + if (point == (*this)->getRegion(1)) // This region also branches back to the parent. regions.push_back(RegionSuccessor()); - regions.push_back( - RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions))); + regions.push_back(RegionSuccessor(region)); } } }; @@ -74,11 +73,11 @@ return "cftest.double_loop_regions_op"; } - void getSuccessorRegions(std::optional index, + void getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { - if (index.has_value()) { + if (Region *region = point.getRegionOrNull()) { regions.push_back(RegionSuccessor()); - regions.push_back(RegionSuccessor(&getOperation()->getRegion(*index))); + regions.push_back(RegionSuccessor(region)); } } }; @@ -92,9 +91,9 @@ static StringRef getOperationName() { return "cftest.sequential_regions_op"; } // Region 0 has Region 1 as a successor. - void getSuccessorRegions(std::optional index, + void getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { - if (index == 0u) { + if (point == (*this)->getRegion(0)) { Operation *thisOp = this->getOperation(); regions.push_back(RegionSuccessor(&thisOp->getRegion(1))); }