[mlir][tosa] Implement RegionBranchOpInterface for tosa.cond_if

tosa.cond_if now declares RegionBranchOpInterface and implements
getSuccessorRegions(), so region-aware passes such as SCCP can tell which
branch is reachable when the condition operand folds to a constant. The
condition operand is narrowed from I1Tensor to a scalar I1 (the implementation
reads the folded condition as a BoolAttr), and else_branch is relaxed from
SizedRegion<1> to AnyRegion so an empty else region is accepted; in that case
the empty else is reported as a branch straight back to the parent op.
Existing cond_if tests are updated for the new condition type, and a new
sccp.mlir test exercises constant propagation through the interface.

NOTE(review): this patch was recovered from a copy whose newlines and
angle-bracketed type parameters (e.g. tensor<f32>, ArrayRef<Attribute>) had
been stripped, and whose "&regions" had been mangled into an HTML entity.
Template arguments and tensor element types (f32 / i32 / i1) were restored from
context; confirm them against the original revision before applying.

Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
===================================================================
--- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/Quant/QuantOps.h"
 #include "mlir/Dialect/Traits.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
===================================================================
--- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -15,6 +15,7 @@
 
 include "mlir/IR/OpBase.td"
 
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
@@ -1623,8 +1624,9 @@
 // Further described in docs/Rationale/RationaleTOSADialect.md .
 //===----------------------------------------------------------------------===//
 def Tosa_IfOp : Tosa_Op<"cond_if", [
-       SingleBlockImplicitTerminator<"YieldOp">,
-       RecursiveSideEffects]> {
+       DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+       RecursiveSideEffects,
+       SingleBlockImplicitTerminator<"YieldOp">]> {
   let summary = "Conditional if operator";
 
   let description = [{
@@ -1633,7 +1635,7 @@
   }];
 
   let arguments = (ins
-    I1Tensor:$cond,
+    I1:$cond,
     Variadic<Tosa_Tensor>:$inputs
   );
 
@@ -1643,7 +1645,7 @@
 
   let regions = (region
     SizedRegion<1>:$then_branch,
-    SizedRegion<1>:$else_branch
+    AnyRegion:$else_branch
   );
 }
 
@@ -1655,8 +1657,8 @@
 //===----------------------------------------------------------------------===//
 def Tosa_WhileOp : Tosa_Op<"while_loop", [
        DeclareOpInterfaceMethods<LoopLikeOpInterface>,
-       SingleBlockImplicitTerminator<"YieldOp">,
-       RecursiveSideEffects]> {
+       RecursiveSideEffects,
+       SingleBlockImplicitTerminator<"YieldOp">]> {
   let summary = "output = input; While (Cond(output)) {output = Body(output)}";
 
   let description = [{
Index: mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
===================================================================
--- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -77,6 +77,36 @@
   return success();
 }
 
+void IfOp::getSuccessorRegions(Optional<unsigned> index,
+                               ArrayRef<Attribute> operands,
+                               SmallVectorImpl<RegionSuccessor> &regions) {
+  // The `then` and the `else` region branch back to the parent operation.
+  if (index.hasValue()) {
+    regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  // Don't consider the else region if it is empty.
+  Region *optionalElseBranch = &else_branch();
+  if (optionalElseBranch->empty())
+    optionalElseBranch = nullptr;
+
+  // Otherwise, the successor is dependent on the condition.
+  bool condition;
+  if (auto condAttr = operands.front().dyn_cast_or_null<BoolAttr>()) {
+    condition = condAttr.getValue();
+  } else {
+    // If the condition isn't constant, both regions may be executed.
+    regions.push_back(RegionSuccessor(&then_branch()));
+    regions.push_back(RegionSuccessor(optionalElseBranch));
+    return;
+  }
+
+  // Add the successor regions using the condition.
+  regions.push_back(
+      RegionSuccessor(condition ? &then_branch() : optionalElseBranch));
+}
+
 //===----------------------------------------------------------------------===//
 // Tosa dialect initialization.
 //===----------------------------------------------------------------------===//
Index: mlir/test/Dialect/Tosa/inlining.mlir
===================================================================
--- mlir/test/Dialect/Tosa/inlining.mlir
+++ mlir/test/Dialect/Tosa/inlining.mlir
@@ -7,7 +7,7 @@
 // Check that both the calls and the functions are eliminated after inlining:
 // CHECK-NOT: @add
 // CHECK-NOT: @sub
-func @inlined_if_fn(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+func @inlined_if_fn(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: i1) -> tensor<f32> {
   %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ( {
   ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):  // no predecessors
     %1 = call @add(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
@@ -16,7 +16,7 @@
   ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):  // no predecessors
     %1 = call @sub(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
     "tosa.yield"(%1) : (tensor<f32>) -> ()
-  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  }) : (i1, tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
 func @add(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> attributes {sym_visibility = "private"} {
Index: mlir/test/Dialect/Tosa/ops.mlir
===================================================================
--- mlir/test/Dialect/Tosa/ops.mlir
+++ mlir/test/Dialect/Tosa/ops.mlir
@@ -477,7 +477,7 @@
 
 // -----
 // CHECK-LABEL: cond_if
-func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: i1) -> tensor<f32> {
   %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ( {
   ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):  // no predecessors
     %1 = "tosa.add"(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
@@ -486,7 +486,7 @@
   ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):  // no predecessors
     %1 = "tosa.sub"(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
     "tosa.yield"(%1) : (tensor<f32>) -> ()
-  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  }) : (i1, tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
 
Index: mlir/test/Dialect/Tosa/sccp.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Tosa/sccp.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt -allow-unregistered-dialect -pass-pipeline="func(sccp)" -split-input-file %s | FileCheck %s
+
+// These constitute three tests that validate constant propagation implemented
+// through RegionBranchOpInterface.
+
+// Test 1: test a constant propagated along taken path identified at compile time.
+// CHECK-LABEL: func @tosa_if_then_fn
+func @tosa_if_then_fn() -> tensor<f32> {
+  %cond = constant true
+  // CHECK: %[[CSTA:.*]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  %a = "tosa.const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  %b = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  %1 = "tosa.cond_if"(%cond) ( {
+  ^bb0:  // no predecessors
+    // CHECK: "tosa.yield"(%[[CSTA]])
+    "tosa.yield"(%a) : (tensor<f32>) -> ()
+  }, {
+  ^bb0:  // no predecessors
+    "tosa.yield"(%b) : (tensor<f32>) -> ()
+  }) : (i1) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// -----
+
+// Test 2: test a constant propagated along not-taken path identified at compile time.
+// CHECK-LABEL: func @tosa_if_else_fn
+func @tosa_if_else_fn() -> tensor<f32> {
+  %cond = constant false
+  %a = "tosa.const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[CSTB:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  %b = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  %1 = "tosa.cond_if"(%cond) ( {
+  ^bb0:  // no predecessors
+    "tosa.yield"(%a) : (tensor<f32>) -> ()
+  }, {
+  ^bb0:  // no predecessors
+    // CHECK: "tosa.yield"(%[[CSTB]])
+    "tosa.yield"(%b) : (tensor<f32>) -> ()
+  }) : (i1) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// -----
+
+// Test 3: condition cannot be evaluated at compile time
+// CHECK-LABEL: func @tosa_if_then_else_fn
+func @tosa_if_then_else_fn(%arg0 : i1, %arg1 : tensor<i32>) -> tensor<i32> {
+  // CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: tensor<i32>)
+  // CHECK: %[[CST:.*]] = "tosa.const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
+  %cst = "tosa.const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
+  %1 = "tosa.cond_if"(%arg0, %arg1) ( {
+  ^bb0:  // no predecessors
+    // CHECK: "tosa.add"
+    %2 = "tosa.add"(%arg1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    "tosa.yield"(%2) : (tensor<i32>) -> ()
+  }, {
+  ^bb0:  // no predecessors
+    // CHECK: "tosa.sub"
+    %2 = "tosa.sub"(%cst, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    "tosa.yield"(%2) : (tensor<i32>) -> ()
+  }) : (i1, tensor<i32>) -> tensor<i32>
+  return %1 : tensor<i32>
+}
+