diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/Quant/QuantOps.h"
 #include "mlir/Dialect/Traits.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -15,6 +15,7 @@
 
 include "mlir/IR/OpBase.td"
 
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
@@ -1624,6 +1625,7 @@
 //===----------------------------------------------------------------------===//
 def Tosa_IfOp : Tosa_Op<"cond_if", [
        SingleBlockImplicitTerminator<"YieldOp">,
+       DeclareOpInterfaceMethods<RegionBranchOpInterface>,
        RecursiveSideEffects]> {
   let summary = "Conditional if operator";
 
@@ -1655,6 +1657,8 @@
 //===----------------------------------------------------------------------===//
 def Tosa_WhileOp : Tosa_Op<"while_loop", [
        DeclareOpInterfaceMethods<LoopLikeOpInterface>,
+       DeclareOpInterfaceMethods<RegionBranchOpInterface,
+                                 ["getSuccessorEntryOperands"]>,
        SingleBlockImplicitTerminator<"YieldOp">,
        RecursiveSideEffects]> {
   let summary = "output = input; While (Cond(output)) {output = Body(output)}";
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -77,6 +77,80 @@
   return success();
 }
 
+/// Returns the boolean value of a constant `tosa.cond_if` condition, or None
+/// when it is not a known constant. Both a scalar IntegerAttr and a splat
+/// integer elements attribute (the form a folded `tensor<i1>` constant takes)
+/// are recognized; a non-splat elements attribute is treated as unknown.
+static Optional<bool> getConstantCondition(Attribute condAttr) {
+  if (auto elements = condAttr.dyn_cast_or_null<DenseIntElementsAttr>()) {
+    if (!elements.isSplat())
+      return llvm::None;
+    return elements.getSplatValue<APInt>().isOneValue();
+  }
+  if (auto scalar = condAttr.dyn_cast_or_null<IntegerAttr>())
+    return scalar.getValue().isOneValue();
+  return llvm::None;
+}
+
+void IfOp::getSuccessorRegions(Optional<unsigned> index,
+                               ArrayRef<Attribute> operands,
+                               SmallVectorImpl<RegionSuccessor> &regions) {
+  // The `then` and the `else` region branch back to the parent operation.
+  if (index.hasValue()) {
+    regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  // Don't consider the else region if it is empty.
+  Region *elseRegion = &else_branch();
+  if (elseRegion->empty())
+    elseRegion = nullptr;
+
+  // Otherwise, the successor depends on the condition, which is the first
+  // operand.
+  Optional<bool> condition = getConstantCondition(operands.front());
+  if (!condition) {
+    // If the condition isn't constant, both regions may be executed.
+    regions.push_back(RegionSuccessor(&then_branch()));
+    regions.push_back(RegionSuccessor(elseRegion));
+    return;
+  }
+
+  // Add the successor regions using the condition.
+  regions.push_back(RegionSuccessor(*condition ? &then_branch() : elseRegion));
+}
+
+OperandRange WhileOp::getSuccessorEntryOperands(unsigned index) {
+  // The loop is only entered through the `cond` region (see
+  // getSuccessorRegions below), so these operands must correspond 1-1 with
+  // the `cond` block arguments reported as that edge's successor inputs.
+  // NOTE(review): assumes every operand of tosa.while_loop is an initial
+  // loop-carried value bound to a `cond` block argument -- confirm.
+  assert(index == 0 && "tosa.while_loop is only entered through cond");
+  return getOperands();
+}
+
+void WhileOp::getSuccessorRegions(Optional<unsigned> index,
+                                  ArrayRef<Attribute> operands,
+                                  SmallVectorImpl<RegionSuccessor> &regions) {
+  // Entering from the parent always runs the `cond` region first.
+  if (!index.hasValue()) {
+    regions.emplace_back(&cond(), cond().getArguments());
+    return;
+  }
+
+  assert(*index < 2 && "there are only two regions in a WhileOp");
+  // After `cond` (region 0): either run `body` or exit the loop.
+  if (*index == 0) {
+    regions.emplace_back(&body(), body().getArguments());
+    regions.emplace_back(getResults());
+    return;
+  }
+
+  // After `body` (region 1): re-evaluate `cond`.
+  regions.emplace_back(&cond(), cond().getArguments());
+}
+
 //===----------------------------------------------------------------------===//
 // Tosa dialect initialization.
 //===----------------------------------------------------------------------===//