Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h =================================================================== --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -15,6 +15,7 @@ #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/Traits.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td =================================================================== --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -15,6 +15,7 @@ include "mlir/IR/OpBase.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Dialect/Tosa/IR/TosaInterfaces.td" @@ -1623,8 +1624,9 @@ // Further described in docs/Rationale/RationaleTOSADialect.md . 
//===----------------------------------------------------------------------===// def Tosa_IfOp : Tosa_Op<"cond_if", [ - SingleBlockImplicitTerminator<"YieldOp">, - RecursiveSideEffects]> { + DeclareOpInterfaceMethods<RegionBranchOpInterface>, + RecursiveSideEffects, + SingleBlockImplicitTerminator<"YieldOp">]> { let summary = "Conditional if operator"; let description = [{ @@ -1655,8 +1657,9 @@ //===----------------------------------------------------------------------===// def Tosa_WhileOp : Tosa_Op<"while_loop", [ DeclareOpInterfaceMethods<LoopLikeOpInterface>, - SingleBlockImplicitTerminator<"YieldOp">, - RecursiveSideEffects]> { + DeclareOpInterfaceMethods<RegionBranchOpInterface>, + RecursiveSideEffects, + SingleBlockImplicitTerminator<"YieldOp">]> { let summary = "output = input; While (Cond(output)) {output = Body(output)}"; let description = [{ Index: mlir/lib/Dialect/Tosa/IR/TosaOps.cpp =================================================================== --- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -77,6 +77,59 @@ return success(); } +void IfOp::getSuccessorRegions(Optional<unsigned> index, + ArrayRef<Attribute> operands, + SmallVectorImpl<RegionSuccessor> &regions) { + // The `then` and the `else` region branch back to the parent operation. + if (index.hasValue()) { + regions.push_back(RegionSuccessor(getResults())); + return; + } + + // Don't consider the else region if it is empty. + Region *elseBranch = &this->else_branch(); + if (elseBranch->empty()) + elseBranch = nullptr; + + // Otherwise, the successor is dependent on the condition. + bool condition; + if (auto condAttr = operands.front().dyn_cast_or_null<BoolAttr>()) { + condition = condAttr.getValue(); + } else { + // If the condition isn't constant, both regions may be executed. + regions.push_back(RegionSuccessor(&then_branch())); + regions.push_back(RegionSuccessor(elseBranch)); + return; + } + + // Add the successor regions using the condition. + regions.push_back(RegionSuccessor(condition ? 
&then_branch() : elseBranch)); +} + +void WhileOp::getSuccessorRegions(Optional<unsigned> index, + ArrayRef<Attribute> operands, + SmallVectorImpl<RegionSuccessor> &regions) { + (void)operands; + + // Where 'cond' executes before 'body', i.e. while-cond-do-body construct. + if (!index.hasValue()) { + regions.emplace_back(&cond(), cond().getArguments()); + return; + } + + assert(*index < 2 && "there are only two regions in a WhileOp"); + // Add the 'body' region and ensure it can pass result to next instance of + // 'cond'. + if (*index == 0) { + regions.emplace_back(&body(), body().getArguments()); + regions.emplace_back(getResults()); + return; + } + + // Support do-body-while-cond construct by adding 'cond'. + regions.emplace_back(&cond(), cond().getArguments()); +} + //===----------------------------------------------------------------------===// // Tosa dialect initialization. //===----------------------------------------------------------------------===// Index: mlir/test/Dialect/Tosa/constant_folding.mlir =================================================================== --- mlir/test/Dialect/Tosa/constant_folding.mlir +++ mlir/test/Dialect/Tosa/constant_folding.mlir @@ -6,3 +6,43 @@ %0 = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> return %0 : tensor<4xi32> } + +// ----- + +// CHECK-LABEL: func @tosa_if_fn +// CHECK-NOT: @add +// CHECK-NOT: @sub +func @tosa_if_fn(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tosa.const"() {value = dense<1> : tensor } : () -> tensor + %1 = "tosa.cond_if"(%0, %arg0, %arg1) ( { + ^bb0: // no predecessors + %2 = "tosa.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tosa.yield"(%2) : (tensor) -> () + }, { + ^bb0: // no predecessors + %2 = "tosa.sub"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tosa.yield"(%2) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + return %1 : tensor +} + +// ----- + +// CHECK-LABEL: func @tosa_while_fn +func @tosa_while_fn(%input: tensor<10xi32>) -> tensor<10xi32> { + %ctr = 
"tosa.const"() {value = dense<0> : tensor } : () -> tensor + %incr = "tosa.const"() {value = dense<1> : tensor } : () -> tensor + %limit = "tosa.const"() {value = dense<5> : tensor } : () -> tensor + %0:4 = "tosa.while_loop"(%ctr, %incr, %limit, %input) ( { + ^bb0: // no predecessors + %1 = "tosa.greater_equal"(%ctr, %limit) : (tensor, tensor) -> tensor + %2 = "tosa.logical_not"(%1) : (tensor) -> tensor + "tosa.yield"(%2) : (tensor) -> () + }, { + ^bb0: // no predecessors + %1 = "tosa.add"(%ctr, %incr) : (tensor, tensor) -> tensor + %2 = "tosa.add"(%input, %ctr) : (tensor<10xi32>, tensor) -> tensor<10xi32> + "tosa.yield"(%1, %incr, %limit, %2) : (tensor, tensor, tensor, tensor<10xi32>) -> () + }) : (tensor, tensor, tensor, tensor<10xi32>) -> (tensor, tensor, tensor, tensor<10xi32>) + return %0#3 : tensor<10xi32> +}