diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -2147,7 +2147,7 @@ } def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods, RecursiveMemoryEffects, + "getRegionInvocationBounds", "getEntrySuccessorRegions"]>, RecursiveMemoryEffects, NoRegionArguments]> { let summary = "if-then-else conditional operation"; let description = [{ diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -3461,15 +3461,13 @@ } } -// These 2 functions copied from scf.if implementation. +// These 3 functions copied from scf.if implementation. /// Given the region at `index`, or the parent operation if `index` is None, /// return the successor regions. These are the regions that may be selected -/// during the flow of control. `operands` is a set of optional attributes that -/// correspond to a constant value for each operand, or null if that operand is -/// not a constant. +/// during the flow of control. void fir::IfOp::getSuccessorRegions( - std::optional index, llvm::ArrayRef operands, + std::optional index, llvm::SmallVectorImpl ®ions) { // The `then` and the `else` region branch back to the parent operation. if (index) { @@ -3477,27 +3475,33 @@ return; } + // Don't consider the else region if it is empty. + regions.push_back(mlir::RegionSuccessor(&getThenRegion())); + // Don't consider the else region if it is empty. mlir::Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - elseRegion = nullptr; + regions.push_back(mlir::RegionSuccessor()); + else + regions.push_back(mlir::RegionSuccessor(elseRegion)); +} - // Otherwise, the successor is dependent on the condition. - bool condition; - if (auto condAttr = operands.front().dyn_cast_or_null()) { - condition = condAttr.getValue().isOne(); - } else { - // If the condition isn't constant, both regions may be executed. - regions.push_back(mlir::RegionSuccessor(&getThenRegion())); - // If the else region does not exist, it is not a viable successor. - if (elseRegion) - regions.push_back(mlir::RegionSuccessor(elseRegion)); - return; +void fir::IfOp::getEntrySuccessorRegions( + llvm::ArrayRef operands, + llvm::SmallVectorImpl ®ions) { + FoldAdaptor adaptor(operands); + auto boolAttr = + mlir::dyn_cast_or_null(adaptor.getCondition()); + if (!boolAttr || boolAttr.getValue()) + regions.emplace_back(&getThenRegion()); + + // If the else region is empty, execution continues after the parent op. + if (!boolAttr || !boolAttr.getValue()) { + if (!getElseRegion().empty()) + regions.emplace_back(&getElseRegion()); + else + regions.emplace_back(getResults()); } - - // Add the successor regions using the condition. - regions.push_back( - mlir::RegionSuccessor(condition ? &getThenRegion() : elseRegion)); } void fir::IfOp::getRegionInvocationBounds( diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h @@ -26,6 +26,7 @@ class CallableOpInterface; class BranchOpInterface; class RegionBranchOpInterface; +class RegionBranchTerminatorOpInterface; namespace dataflow { @@ -207,7 +208,8 @@ /// Visit the given terminator operation that exits a region under an /// operation with control-flow semantics. These are terminators with no CFG /// successors. - void visitRegionTerminator(Operation *op, RegionBranchOpInterface branch); + void visitRegionTerminator(RegionBranchTerminatorOpInterface op, + RegionBranchOpInterface branch); /// Visit the given terminator operation that exits a callable region. These /// are terminators with no CFG successors. diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -123,7 +123,7 @@ ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep", "getSingleUpperBound"]>, DeclareOpInterfaceMethods]> { + ["getEntrySuccessorOperands"]>]> { let summary = "for operation"; let description = [{ Syntax: diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -35,7 +35,7 @@ def Async_ExecuteOp : Async_Op<"execute", [SingleBlockImplicitTerminator<"YieldOp">, DeclareOpInterfaceMethods, AttrSizedOperandSegments, AutomaticAllocationScope]> { @@ -312,8 +312,7 @@ def Async_YieldOp : Async_Op<"yield", [ - HasParent<"ExecuteOp">, Pure, Terminator, - DeclareOpInterfaceMethods]> { + HasParent<"ExecuteOp">, Pure, Terminator, ReturnLike]> { let summary = "terminator for Async execute operation"; let description = [{ The `async.yield` is a special terminator operation for the block inside @@ -322,7 +321,6 @@ let arguments = (ins Variadic:$operands); let assemblyFormat = "($operands^ `:` type($operands))? attr-dict"; - let hasVerifier = 1; } def Async_AwaitOp : Async_Op<"await"> { diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -39,7 +39,8 @@ def ConditionOp : SCF_Op<"condition", [ HasParent<"WhileOp">, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, Pure, Terminator ]> { @@ -124,7 +125,8 @@ "getSingleUpperBound", "promoteIfSingleIteration"]>, AllTypesMatch<["lowerBound", "upperBound", "step"]>, ConditionallySpeculatable, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects]> { let summary = "for operation"; @@ -335,12 +337,6 @@ getNumControlOperands() + opResult.getResultNumber()); } - /// Return operands used when entering the region at 'index'. These operands - /// correspond to the loop iterator operands, i.e., those exclusing the - /// induction variable. LoopOp only has one region, so 0 is the only valid - /// value for `index`. - OperandRange getSuccessorEntryOperands(std::optional index); - /// Returns the step as an `APInt` if it is constant. std::optional getConstantStep(); @@ -712,7 +708,8 @@ //===----------------------------------------------------------------------===// def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods, + "getNumRegionInvocations", "getRegionInvocationBounds", + "getEntrySuccessorRegions"]>, InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects, NoRegionArguments]> { let summary = "if-then-else operation"; @@ -978,7 +975,8 @@ //===----------------------------------------------------------------------===// def WhileOp : SCF_Op<"while", - [DeclareOpInterfaceMethods, + [DeclareOpInterfaceMethods, RecursiveMemoryEffects]> { let summary = "a generic 'while' loop"; let description = [{ @@ -1108,7 +1106,6 @@ using BodyBuilderFn = function_ref; - OperandRange getSuccessorEntryOperands(std::optional index); ConditionOp getConditionOp(); YieldOp getYieldOp(); Block::BlockArgListType getBeforeArguments(); @@ -1127,7 +1124,8 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"scf::YieldOp">, DeclareOpInterfaceMethods]> { + ["getRegionInvocationBounds", + "getEntrySuccessorRegions"]>]> { let summary = "switch-case operation on an index argument"; let description = [{ The `scf.index_switch` is a control-flow operation that branches to one of diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -25,7 +25,7 @@ def AlternativesOp : TransformDialectOp<"alternatives", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, @@ -507,7 +507,7 @@ [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + "getSuccessorRegions", "getEntrySuccessorOperands"]>, SingleBlockImplicitTerminator<"::mlir::transform::YieldOp"> ]> { let summary = "Executes the body for each payload op"; @@ -1016,7 +1016,7 @@ def SequenceOp : TransformDialectOp<"sequence", [DeclareOpInterfaceMethods, MatchOpInterface, DeclareOpInterfaceMethods, diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -134,36 +134,50 @@ InterfaceMethod<[{ Returns the operands of this operation used as the entry arguments when entering the region at `index`, which was specified as a successor of - this operation by `getSuccessorRegions`, or the operands forwarded to - the operation's results when it branches back to itself. These operands + this operation by `getEntrySuccessorRegions`, or the operands forwarded + to the operation's results when it branches back to itself. These operands should correspond 1-1 with the successor inputs specified in - `getSuccessorRegions`. + `getEntrySuccessorRegions`. }], - "::mlir::OperandRange", "getSuccessorEntryOperands", + "::mlir::OperandRange", "getEntrySuccessorOperands", (ins "::std::optional":$index), [{}], /*defaultImplementation=*/[{ auto operandEnd = this->getOperation()->operand_end(); return ::mlir::OperandRange(operandEnd, operandEnd); }] >, + InterfaceMethod<[{ + Returns the viable region successors that are branched to when first + executing the op. + Unlike `getSuccessorRegions`, this method also passes along the + constant operands of this op. Based on these, different region + successors can be determined. + `operands` contains an entry for every operand of the implementing + op with a null attribute if the operand has no constant value or + the corresponding attribute if it is a constant. + + By default, simply dispatches to `getSuccessorRegions`. + }], + "void", "getEntrySuccessorRegions", + (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands, + "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), + [{}], [{ + $_op.getSuccessorRegions(std::nullopt, regions); + }] + >, InterfaceMethod<[{ Returns the viable successors of a region at `index`, or the possible successors when branching from the parent op if `index` is None. These - are the regions that may be selected during the flow of control. If - `index` is None, `operands` is a set of optional attributes that - either correspond to a constant value for each operand of this - operation, or null if that operand is not a constant. If `index` is - valid, `operands` corresponds to the entry values of the region at - `index`. The parent operation, i.e. a null `index`, may specify itself - as successor, which indicates that the control flow may not enter any - region at all. This method allows for describing which - regions may be executed when entering an operation, and which regions - are executed after having executed another region of the parent op. The - successor region must be non-empty. + are the regions that may be selected during the flow of control. The + parent operation, i.e. a null `index`, may specify itself as successor, + which indicates that the control flow may not enter any region at all. + This method allows for describing which regions may be executed when + entering an operation, and which regions are executed after having + executed another region of the parent op. The successor region must be + non-empty. }], "void", "getSuccessorRegions", (ins "::std::optional":$index, - "::llvm::ArrayRef<::mlir::Attribute>":$operands, "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions) >, InterfaceMethod<[{ @@ -208,10 +222,6 @@ let verifyWithRegions = 1; let extraClassDeclaration = [{ - /// Convenience helper in case none of the operands is known. - void getSuccessorRegions(std::optional index, - SmallVectorImpl ®ions); - /// Return `true` if control flow originating from the given region may /// eventually branch back to the same region. (Maybe after passing through /// other regions.) @@ -243,17 +253,26 @@ (ins "::std::optional":$index) >, InterfaceMethod<[{ - Returns a range of operands that are semantically "returned" by passing - them to the region successor given by `index`. If `index` is None, this - function returns the operands that are passed as a result to the parent - operation. + Returns the viable region successors that are branched to after this + terminator based on the given constant operands. + + `operands` contains an entry for every operand of the + implementing op with a null attribute if the operand has no constant + value or the corresponding attribute if it is a constant. + + Default implementation simply dispatches to the parent + `RegionBranchOpInterface`'s `getSuccessorRegions` implementation. }], - "::mlir::OperandRange", "getSuccessorOperands", - (ins "::std::optional":$index), [{}], - /*defaultImplementation=*/[{ - return $_op.getMutableSuccessorOperands(index); + "void", "getSuccessorRegions", + (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands, + "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}], + [{ + ::mlir::Operation *op = $_op; + ::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp()) + .getSuccessorRegions(op->getParentRegion()->getRegionNumber(), + regions); }] - > + >, ]; let verify = [{ @@ -265,6 +284,16 @@ "expected operation to have zero successors"); return success(); }]; + + let extraClassDeclaration = [{ + // Returns a range of operands that are semantically "returned" by passing + // them to the region successor given by `index`. If `index` is None, this + // function returns the operands that are passed as a result to the parent + // operation. + ::mlir::OperandRange getSuccessorOperands(std::optional index) { + return getMutableSuccessorOperands(index); + } + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -83,7 +83,7 @@ if (std::optional operandIndex = getOperandIndexIfPred(/*predIndex=*/std::nullopt)) { collectUnderlyingAddressValues( - branch.getSuccessorEntryOperands(regionIndex)[*operandIndex], maxDepth, + branch.getEntrySuccessorOperands(regionIndex)[*operandIndex], maxDepth, visited, output); } // Check branches from each child region. diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -259,7 +259,8 @@ if (isRegionOrCallableReturn(op)) { if (auto branch = dyn_cast(op->getParentOp())) { // Visit the exiting terminator of a region. - visitRegionTerminator(op, branch); + visitRegionTerminator(cast(op), + branch); } else if (auto callable = dyn_cast(op->getParentOp())) { // Visit the exiting terminator of a callable. @@ -361,7 +362,7 @@ return; SmallVector successors; - branch.getSuccessorRegions(/*index=*/{}, *operands, successors); + branch.getEntrySuccessorRegions(*operands, successors); for (const RegionSuccessor &successor : successors) { // The successor can be either an entry block or the parent operation. ProgramPoint point = successor.getSuccessor() @@ -378,15 +379,14 @@ } } -void DeadCodeAnalysis::visitRegionTerminator(Operation *op, - RegionBranchOpInterface branch) { +void DeadCodeAnalysis::visitRegionTerminator( + RegionBranchTerminatorOpInterface op, RegionBranchOpInterface branch) { std::optional> operands = getOperandValues(op); if (!operands) return; SmallVector successors; - branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(), - *operands, successors); + op.getSuccessorRegions(*operands, successors); // Mark successor region entry blocks as executable and add this op to the // list of predecessors. diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -224,7 +224,7 @@ // Check if the predecessor is the parent op. if (op == branch) { - operands = branch.getSuccessorEntryOperands(successorIndex); + operands = branch.getEntrySuccessorOperands(successorIndex); // Otherwise, try to deduce the operands from a region return-like op. } else if (auto regionTerminator = dyn_cast(op)) { @@ -479,7 +479,7 @@ Operation *op = branch.getOperation(); SmallVector successors; SmallVector operands(op->getNumOperands(), nullptr); - branch.getSuccessorRegions(/*index=*/{}, operands, successors); + branch.getEntrySuccessorRegions(operands, successors); // All operands not forwarded to any successor. This set can be non-contiguous // in the presence of multiple successors. @@ -488,8 +488,8 @@ for (RegionSuccessor &successor : successors) { Region *region = successor.getSuccessor(); OperandRange operands = - region ? branch.getSuccessorEntryOperands(region->getRegionNumber()) - : branch.getSuccessorEntryOperands({}); + region ? branch.getEntrySuccessorOperands(region->getRegionNumber()) + : branch.getEntrySuccessorOperands({}); MutableArrayRef opoperands = operandsToOpOperands(operands); ValueRange inputs = successor.getSuccessorInputs(); for (auto [operand, input] : llvm::zip(opoperands, inputs)) { @@ -516,8 +516,7 @@ SmallVector operandAttributes(terminator->getNumOperands(), nullptr); SmallVector successors; - branch.getSuccessorRegions(terminator->getParentRegion()->getRegionNumber(), - operandAttributes, successors); + terminator.getSuccessorRegions(operandAttributes, successors); // All operands not forwarded to any successor. This set can be // non-contiguous in the presence of multiple successors. BitVector unaccounted(terminator->getNumOperands(), true); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2380,7 +2380,7 @@ /// induction variable. AffineForOp only has one region, so zero is the only /// valid value for `index`. OperandRange -AffineForOp::getSuccessorEntryOperands(std::optional index) { +AffineForOp::getEntrySuccessorOperands(std::optional index) { assert((!index || *index == 0) && "invalid region index"); // The initial operands map to the loop arguments after the induction @@ -2394,8 +2394,7 @@ /// correspond to a constant value for each operand, or null if that operand is /// not a constant. void AffineForOp::getSuccessorRegions( - std::optional index, ArrayRef operands, - SmallVectorImpl ®ions) { + std::optional index, SmallVectorImpl ®ions) { assert((!index.has_value() || index.value() == 0) && "expected loop region"); // The loop may typically branch back to its body or to the parent operation. // If the predecessor is the parent op and the trip count is known to be at @@ -2860,8 +2859,7 @@ /// AffineIfOp has two regions -- `then` and `else`. The flow of data should be /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp void AffineIfOp::getSuccessorRegions( - std::optional index, ArrayRef operands, - SmallVectorImpl ®ions) { + std::optional index, SmallVectorImpl ®ions) { // If the predecessor is an AffineIfOp, then branching into both `then` and // `else` region is valid. if (!index.has_value()) { diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -32,31 +32,6 @@ >(); } -//===----------------------------------------------------------------------===// -// YieldOp -//===----------------------------------------------------------------------===// - -LogicalResult YieldOp::verify() { - // Get the underlying value types from async values returned from the - // parent `async.execute` operation. - auto executeOp = (*this)->getParentOfType(); - auto types = - llvm::map_range(executeOp.getBodyResults(), [](const OpResult &result) { - return llvm::cast(result.getType()).getValueType(); - }); - - if (getOperandTypes() != types) - return emitOpError("operand types do not match the types returned from " - "the parent ExecuteOp"); - - return success(); -} - -MutableOperandRange -YieldOp::getMutableSuccessorOperands(std::optional index) { - return getOperandsMutable(); -} - //===----------------------------------------------------------------------===// /// ExecuteOp //===----------------------------------------------------------------------===// @@ -64,7 +39,7 @@ constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes"; OperandRange -ExecuteOp::getSuccessorEntryOperands(std::optional index) { +ExecuteOp::getEntrySuccessorOperands(std::optional index) { assert(index && *index == 0 && "invalid region index"); return getBodyOperands(); } @@ -79,7 +54,6 @@ } void ExecuteOp::getSuccessorRegions(std::optional index, - ArrayRef, SmallVectorImpl ®ions) { // The `body` region branch back to the parent operation. if (index) { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp @@ -384,7 +384,7 @@ // Determine the actual operand to introduce a clone for and rewire the // operand to point to the clone instead. auto operands = - regionInterface.getSuccessorEntryOperands(argRegion->getRegionNumber()); + regionInterface.getEntrySuccessorOperands(argRegion->getRegionNumber()); size_t operandIndex = llvm::find(it->getSuccessorInputs(), blockArg).getIndex() + operands.getBeginOperandIndex(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp @@ -106,7 +106,7 @@ // Wire the entry region's successor arguments with the initial // successor inputs. registerDependencies( - regionInterface.getSuccessorEntryOperands( + regionInterface.getEntrySuccessorOperands( entrySuccessor.isParent() ? std::optional() : entrySuccessor.getSuccessor()->getRegionNumber()), diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -461,8 +461,7 @@ } void AllocaScopeOp::getSuccessorRegions( - std::optional index, ArrayRef operands, - SmallVectorImpl ®ions) { + std::optional index, SmallVectorImpl ®ions) { if (index) { regions.push_back(RegionSuccessor(getResults())); return; diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -266,8 +266,7 @@ /// correspond to a constant value for each operand, or null if that operand is /// not a constant. void ExecuteRegionOp::getSuccessorRegions( - std::optional index, ArrayRef operands, - SmallVectorImpl ®ions) { + std::optional index, SmallVectorImpl ®ions) { // If the predecessor is the ExecuteRegionOp, branch into the body. if (!index) { regions.push_back(RegionSuccessor(&getRegion())); @@ -288,6 +287,22 @@ return getArgsMutable(); } +void ConditionOp::getSuccessorRegions( + ArrayRef operands, SmallVectorImpl ®ions) { + FoldAdaptor adaptor(operands); + + WhileOp whileOp = getParentOp(); + + // Condition can either lead to the after region or back to the parent op + // depending on whether the condition is true or not. + auto boolAttr = dyn_cast_or_null(adaptor.getCondition()); + if (!boolAttr || boolAttr.getValue()) + regions.emplace_back(&whileOp.getAfter(), + whileOp.getAfter().getArguments()); + if (!boolAttr || !boolAttr.getValue()) + regions.emplace_back(whileOp.getResults()); +} + //===----------------------------------------------------------------------===// // ForOp //===----------------------------------------------------------------------===// @@ -535,7 +550,7 @@ /// Return operands used when entering the region at 'index'. These operands /// correspond to the loop iterator operands, i.e., those excluding the /// induction variable. -OperandRange ForOp::getSuccessorEntryOperands(std::optional index) { +OperandRange ForOp::getEntrySuccessorOperands(std::optional index) { return getInitArgs(); } @@ -545,7 +560,6 @@ /// correspond to a constant value for each operand, or null if that operand is /// not a constant. void ForOp::getSuccessorRegions(std::optional index, - ArrayRef operands, SmallVectorImpl ®ions) { // Both the operation itself and the region may be branching into the body or // back into the operation itself. It is possible for loop not to enter the @@ -1715,7 +1729,6 @@ /// correspond to a constant value for each operand, or null if that operand is /// not a constant. void ForallOp::getSuccessorRegions(std::optional index, - ArrayRef operands, SmallVectorImpl ®ions) { // Both the operation itself and the region may be branching into the body or // back into the operation itself. It is possible for loop not to enter the @@ -1996,7 +2009,6 @@ /// correspond to a constant value for each operand, or null if that operand is /// not a constant. void IfOp::getSuccessorRegions(std::optional index, - ArrayRef operands, SmallVectorImpl ®ions) { // The `then` and the `else` region branch back to the parent operation. if (index) { @@ -2004,29 +2016,30 @@ return; } + regions.push_back(RegionSuccessor(&getThenRegion())); + // Don't consider the else region if it is empty. Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - elseRegion = nullptr; - - // Otherwise, the successor is dependent on the condition. - bool condition; - if (auto condAttr = llvm::dyn_cast_or_null(operands.front())) { - condition = condAttr.getValue().isOne(); - } else { - // If the condition isn't constant, both regions may be executed. - regions.push_back(RegionSuccessor(&getThenRegion())); - // If the else region does not exist, it is not a viable successor, so the - // control will go back to this operation instead. - if (elseRegion) - regions.push_back(RegionSuccessor(elseRegion)); + regions.push_back(RegionSuccessor()); + else + regions.push_back(RegionSuccessor(elseRegion)); +} + +void IfOp::getEntrySuccessorRegions(ArrayRef operands, + SmallVectorImpl ®ions) { + FoldAdaptor adaptor(operands); + auto boolAttr = dyn_cast_or_null(adaptor.getCondition()); + if (!boolAttr || boolAttr.getValue()) + regions.emplace_back(&getThenRegion()); + + // If the else region is empty, execution continues after the parent op. + if (!boolAttr || !boolAttr.getValue()) { + if (!getElseRegion().empty()) + regions.emplace_back(&getElseRegion()); else - regions.push_back(RegionSuccessor()); - return; + regions.emplace_back(getResults()); } - - // Add the successor regions using the condition. - regions.push_back(RegionSuccessor(condition ? &getThenRegion() : elseRegion)); } LogicalResult IfOp::fold(FoldAdaptor adaptor, @@ -3026,8 +3039,7 @@ /// correspond to a constant value for each operand, or null if that operand is /// not a constant. void ParallelOp::getSuccessorRegions( - std::optional index, ArrayRef operands, - SmallVectorImpl ®ions) { + std::optional index, SmallVectorImpl ®ions) { // Both the operation itself and the region may be branching into the body or // back into the operation itself. It is possible for loop not to enter the // body. @@ -3154,7 +3166,7 @@ afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments()); } -OperandRange WhileOp::getSuccessorEntryOperands(std::optional index) { +OperandRange WhileOp::getEntrySuccessorOperands(std::optional index) { assert(index && *index == 0 && "WhileOp is expected to branch only to the first region"); @@ -3178,7 +3190,6 @@ } void WhileOp::getSuccessorRegions(std::optional index, - ArrayRef operands, SmallVectorImpl ®ions) { // The parent op always branches to the condition region. if (!index) { @@ -3193,13 +3204,8 @@ return; } - // Try to narrow the successor to the condition region. - assert(!operands.empty() && "expected at least one operand"); - auto cond = llvm::dyn_cast_or_null(operands[0]); - if (!cond || !cond.getValue()) - regions.emplace_back(getResults()); - if (!cond || cond.getValue()) - regions.emplace_back(&getAfter(), getAfter().getArguments()); + regions.emplace_back(getResults()); + regions.emplace_back(&getAfter(), getAfter().getArguments()); } /// Parses a `while` op. @@ -4016,7 +4022,7 @@ } void IndexSwitchOp::getSuccessorRegions( - std::optional index, ArrayRef operands, + std::optional index, SmallVectorImpl &successors) { // All regions branch back to the parent op. if (index) { @@ -4024,19 +4030,25 @@ return; } + llvm::copy(getRegions(), std::back_inserter(successors)); +} + +void IndexSwitchOp::getEntrySuccessorRegions( + ArrayRef operands, + SmallVectorImpl &successors) { + FoldAdaptor adaptor(operands); + // If a constant was not provided, all regions are possible successors. - auto operandValue = llvm::dyn_cast_or_null(operands.front()); - if (!operandValue) { - for (Region &caseRegion : getCaseRegions()) - successors.emplace_back(&caseRegion); - successors.emplace_back(&getDefaultRegion()); + auto arg = dyn_cast_or_null(adaptor.getArg()); + if (!arg) { + llvm::copy(getRegions(), std::back_inserter(successors)); return; } - // Otherwise, try to find a case with a matching value. If not, the default - // region is the only successor. + // Otherwise, try to find a case with a matching value. If not, the + // default region is the only successor. for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) { - if (caseValue == operandValue.getInt()) { + if (caseValue == arg.getInt()) { successors.emplace_back(&caseRegion); return; } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -335,8 +335,7 @@ // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td void AssumingOp::getSuccessorRegions( - std::optional index, ArrayRef operands, - SmallVectorImpl ®ions) { + std::optional index, SmallVectorImpl ®ions) { // AssumingOp has unconditional control flow into the region and back to the // parent, so return the correct RegionSuccessor purely based on the index // being None or 0. diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -85,7 +85,7 @@ // AlternativesOp //===----------------------------------------------------------------------===// -OperandRange transform::AlternativesOp::getSuccessorEntryOperands( +OperandRange transform::AlternativesOp::getEntrySuccessorOperands( std::optional index) { if (index && getOperation()->getNumOperands() == 1) return getOperation()->getOperands(); @@ -94,8 +94,7 @@ } void transform::AlternativesOp::getSuccessorRegions( - std::optional index, ArrayRef operands, - SmallVectorImpl ®ions) { + std::optional index, SmallVectorImpl ®ions) { for (Region &alternative : llvm::drop_begin( getAlternatives(), index.has_value() ? *index + 1 : 0)) { regions.emplace_back(&alternative, !getOperands().empty() @@ -1162,8 +1161,7 @@ } void transform::ForeachOp::getSuccessorRegions( - std::optional index, ArrayRef operands, - SmallVectorImpl ®ions) { + std::optional index, SmallVectorImpl ®ions) { Region *bodyRegion = &getBody(); if (!index) { regions.emplace_back(bodyRegion, bodyRegion->getArguments()); @@ -1177,7 +1175,7 @@ } OperandRange -transform::ForeachOp::getSuccessorEntryOperands(std::optional index) { +transform::ForeachOp::getEntrySuccessorOperands(std::optional index) { // The iteration variable op handle is mapped to a subset (one op to be // precise) of the payload ops of the ForeachOp operand. assert(index && *index == 0 && "unexpected region index"); @@ -2182,7 +2180,7 @@ getPotentialTopLevelEffects(effects); } -OperandRange transform::SequenceOp::getSuccessorEntryOperands( +OperandRange transform::SequenceOp::getEntrySuccessorOperands( std::optional index) { assert(index && *index == 0 && "unexpected region index"); if (getOperation()->getNumOperands() > 0) @@ -2192,11 +2190,10 @@ } void transform::SequenceOp::getSuccessorRegions( - std::optional index, ArrayRef operands, - SmallVectorImpl ®ions) { + std::optional index, SmallVectorImpl ®ions) { if (!index) { Region *bodyRegion = &getBody(); - regions.emplace_back(bodyRegion, !operands.empty() + regions.emplace_back(bodyRegion, getNumOperands() != 0 ? bodyRegion->getArguments() : Block::BlockArgListType()); return; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5671,8 +5671,7 @@ } void WarpExecuteOnLane0Op::getSuccessorRegions( - std::optional index, ArrayRef operands, - SmallVectorImpl ®ions) { + std::optional index, SmallVectorImpl ®ions) { if (index) { regions.push_back(RegionSuccessor(getResults())); return; diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -154,7 +154,7 @@ auto inputTypesFromParent = [&](std::optional regionNo) -> TypeRange { - return regionInterface.getSuccessorEntryOperands(regionNo).getTypes(); + return regionInterface.getEntrySuccessorOperands(regionNo).getTypes(); }; // Verify types along control flow edges originating from the parent. @@ -309,27 +309,6 @@ return isRegionReachable(region, region); } -void RegionBranchOpInterface::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { - unsigned numInputs = 0; - if (index) { - // If the predecessor is a region, get the number of operands from an - // exiting terminator in the region. - for (Block &block : getOperation()->getRegion(*index)) { - Operation *terminator = block.getTerminator(); - if (isa(terminator)) { - numInputs = terminator->getNumOperands(); - break; - } - } - } else { - // Otherwise, use the number of parent operation operands. - numInputs = getOperation()->getNumOperands(); - } - SmallVector operands(numInputs, nullptr); - getSuccessorRegions(index, operands, regions); -} - Region *mlir::getEnclosingRepetitiveRegion(Operation *op) { while (Region *region = op->getParentRegion()) { op = region->getParentOp(); diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -932,14 +932,13 @@ } OperandRange -RegionIfOp::getSuccessorEntryOperands(std::optional index) { +RegionIfOp::getEntrySuccessorOperands(std::optional index) { assert(index && *index < 2 && "invalid region index"); return getOperands(); } void RegionIfOp::getSuccessorRegions( - std::optional index, ArrayRef operands, - SmallVectorImpl ®ions) { + std::optional index, SmallVectorImpl ®ions) { // We always branch to the join region. if (index.has_value()) { if (index.value() < 2) @@ -966,7 +965,6 @@ //===----------------------------------------------------------------------===// void AnyCondOp::getSuccessorRegions(std::optional index, - ArrayRef operands, SmallVectorImpl ®ions) { // The parent op branches into the only region, and the region branches back // to the parent op. @@ -1268,8 +1266,7 @@ } void TestStoreWithARegion::getSuccessorRegions( - std::optional index, ArrayRef operands, - SmallVectorImpl ®ions) { + std::optional index, SmallVectorImpl ®ions) { if (!index) { regions.emplace_back(&getBody(), getBody().front().getArguments()); } else { @@ -1277,11 +1274,6 @@ } } -MutableOperandRange TestStoreWithARegionTerminator::getMutableSuccessorOperands( - std::optional index) { - return MutableOperandRange(getOperation()); -} - LogicalResult TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2047,7 +2047,8 @@ def RegionIfOp : TEST_Op<"region_if", [DeclareOpInterfaceMethods, + ["getRegionInvocationBounds", + "getEntrySuccessorOperands"]>, SingleBlockImplicitTerminator<"RegionIfYieldOp">, RecursiveMemoryEffects]> { let description =[{ @@ -2071,8 +2072,6 @@ ::mlir::Block::BlockArgListType getJoinArgs() { return getBody(2)->getArguments(); } - ::mlir::OperandRange getSuccessorEntryOperands( - ::std::optional index); }]; let hasCustomAssemblyFormat = 1; } @@ -2824,7 +2823,7 @@ } def TestStoreWithARegionTerminator : TEST_Op<"store_with_a_region_terminator", - [DeclareOpInterfaceMethods, Terminator, NoMemoryEffect]> { + [ReturnLike, Terminator, NoMemoryEffect]> { let assemblyFormat = "attr-dict"; } diff --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp --- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp +++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp @@ -38,7 +38,6 @@ // Regions have no successors. void getSuccessorRegions(std::optional index, - ArrayRef operands, SmallVectorImpl ®ions) {} }; @@ -53,7 +52,6 @@ static StringRef getOperationName() { return "cftest.loop_regions_op"; } void getSuccessorRegions(std::optional index, - ArrayRef operands, SmallVectorImpl ®ions) { if (index) { if (*index == 1) @@ -77,7 +75,6 @@ } void getSuccessorRegions(std::optional index, - ArrayRef operands, SmallVectorImpl ®ions) { if (index.has_value()) { regions.push_back(RegionSuccessor()); @@ -96,7 +93,6 @@ // Region 0 has Region 1 as a successor. void getSuccessorRegions(std::optional index, - ArrayRef operands, SmallVectorImpl ®ions) { if (index == 0u) { Operation *thisOp = this->getOperation();