diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -319,7 +319,9 @@ } def IfOp : SCF_Op<"if", - [DeclareOpInterfaceMethods, + [DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveSideEffects, NoRegionArguments]> { let summary = "if-then-else operation"; @@ -403,12 +405,6 @@ YieldOp thenYield(); Block* elseBlock(); YieldOp elseYield(); - - /// If the condition is a constant, returns 1 for the executed block and 0 - /// for the other. Otherwise, returns `kUnknownNumRegionInvocations` for - /// both successors. - void getNumRegionInvocations(ArrayRef operands, - SmallVectorImpl &countPerRegion); }]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -1198,6 +1198,23 @@ } } +/// Each region of the op is invoked at most once. If the condition is constant, +/// then the number of invocations can be narrowed to either 0 or 1. +void IfOp::getRegionInvocationBounds( + ArrayRef operands, + SmallVectorImpl &invocationBounds) { + if (auto condAttr = operands.front().dyn_cast_or_null()) { + // Constant condition: the taken region gets bounds [0, 1] and the other + // region gets [0, 0], i.e. it is never executed. + bool cond = condAttr.getValue().isOneValue(); + invocationBounds.assign(1, {0, cond ? 1 : 0}); + invocationBounds.emplace_back(0, cond ? 0 : 1); + } else { + // Non-constant condition: each region may run zero or one times. + invocationBounds.assign(2, {0, 1}); + } +} + namespace { // Pattern to remove unused IfOp results.
struct RemoveUnusedResults : public OpRewritePattern { diff --git a/mlir/test/Transforms/control-flow-sink.mlir b/mlir/test/Transforms/control-flow-sink.mlir --- a/mlir/test/Transforms/control-flow-sink.mlir +++ b/mlir/test/Transforms/control-flow-sink.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -control-flow-sink %s | FileCheck %s +// RUN: mlir-opt -split-input-file -control-flow-sink %s | FileCheck %s // CHECK-LABEL: @test_simple_sink // CHECK: %0 = arith.subi %arg2, %arg1 : i32 @@ -32,6 +32,8 @@ return %result : i32 } +// ----- + // CHECK-LABEL: @test_region_sink // CHECK: %0 = test.region_if %arg0: i1 -> i32 then { // CHECK: %1 = test.region_if %arg0: i1 -> i32 then { @@ -68,6 +70,8 @@ return %result1 : i32 } +// ----- + // CHECK-LABEL: @test_subgraph_sink // CHECK: %0 = test.region_if %arg0: i1 -> i32 then { // CHECK: %1 = arith.subi %arg1, %arg2 : i32 @@ -100,6 +104,8 @@ return %result : i32 } +// ----- + // CHECK-LABEL: @test_multiblock_region_sink // CHECK: %0 = arith.addi %arg1, %arg2 : i32 // CHECK: %1 = "test.any_cond"() ( { @@ -128,3 +134,44 @@ %result = arith.addi %0, %3 : i32 return %result : i32 } + +// ----- + +// CHECK-LABEL: @test_scf_if_sink +// CHECK: %0 = scf.if %arg0 -> (i32) { +// CHECK: %1 = arith.addi %arg1, %arg1 : i32 +// CHECK: scf.yield %1 : i32 +// CHECK: } else { +// CHECK: %1 = arith.muli %arg1, %arg1 : i32 +// CHECK: scf.yield %1 : i32 +// CHECK: } +// CHECK: return %0 : i32 +func @test_scf_if_sink(%pred : i1, %v : i32) -> i32 { + %sum = arith.addi %v, %v : i32 + %prod = arith.muli %v, %v : i32 + %res = scf.if %pred -> i32 { + scf.yield %sum : i32 + } else { + scf.yield %prod : i32 + } + return %res : i32 +} + +// ----- + +func private @consume(i32) -> () + +// CHECK-LABEL: @test_scf_if_then_only_sink +// CHECK: scf.if %arg0 { +// CHECK: %0 = arith.addi %arg1, %arg1 : i32 +// CHECK: call @consume(%0) : (i32) -> () +// CHECK: } +// CHECK: return +func @test_scf_if_then_only_sink(%pred : i1, %v : i32) { + %sum = arith.addi
%v, %v : i32 + scf.if %pred { + call @consume(%sum) : (i32) -> () + scf.yield + } + return +}