diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -363,6 +363,7 @@
        AutomaticAllocationScope,
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
+       DeclareOpInterfaceMethods<RegionBranchOpInterface>,
      ]> {
   let summary = "evaluate a block multiple times in parallel";
   let description = [{
@@ -814,6 +815,7 @@
      AttrSizedOperandSegments,
      DeclareOpInterfaceMethods<LoopLikeOpInterface>,
      RecursiveMemoryEffects,
+     DeclareOpInterfaceMethods<RegionBranchOpInterface>,
      SingleBlockImplicitTerminator<"scf::YieldOp">]> {
   let summary = "parallel for operation";
   let description = [{
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1682,6 +1682,23 @@
               ForallOpSingleOrZeroIterationDimsFolder>(context);
 }
 
+/// Given the region at `index`, or the parent operation if `index` is
+/// std::nullopt, return the successor regions. These are the regions that may
+/// be selected during the flow of control. `operands` is a set of optional
+/// attributes that correspond to a constant value for each operand, or null if
+/// that operand is not a constant.
+void ForallOp::getSuccessorRegions(std::optional<unsigned> index,
+                                   ArrayRef<Attribute> operands,
+                                   SmallVectorImpl<RegionSuccessor> &regions) {
+  assert((!index || *index == 0) && "expected loop region");
+  // The loop may run zero, one, or many iterations. Hence both the parent op
+  // and the body can branch into the body (first or next iteration) or back
+  // to the parent op (no iteration / loop done). No values are forwarded as
+  // successor inputs or results.
+  regions.push_back(RegionSuccessor(&getRegion()));
+  regions.push_back(RegionSuccessor());
+}
+
 //===----------------------------------------------------------------------===//
 // InParallelOp
 //===----------------------------------------------------------------------===//
@@ -2976,6 +2993,23 @@
           context);
 }
 
+/// Given the region at `index`, or the parent operation if `index` is
+/// std::nullopt, return the successor regions. These are the regions that may
+/// be selected during the flow of control. `operands` is a set of optional
+/// attributes that correspond to a constant value for each operand, or null if
+/// that operand is not a constant.
+void ParallelOp::getSuccessorRegions(
+    std::optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  assert((!index || *index == 0) && "expected loop region");
+  // The loop may run zero, one, or many iterations. Hence both the parent op
+  // and the body can branch into the body (first or next iteration) or back
+  // to the parent op (no iteration / loop done). No values are forwarded as
+  // successor inputs or results.
+  regions.push_back(RegionSuccessor(&getRegion()));
+  regions.push_back(RegionSuccessor());
+}
+
 //===----------------------------------------------------------------------===//
 // ReduceOp
 //===----------------------------------------------------------------------===//