diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -319,7 +319,9 @@ } def IfOp : SCF_Op<"if", - [DeclareOpInterfaceMethods, + [DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveSideEffects, NoRegionArguments]> { let summary = "if-then-else operation"; @@ -403,12 +405,6 @@ YieldOp thenYield(); Block* elseBlock(); YieldOp elseYield(); - - /// If the condition is a constant, returns 1 for the executed block and 0 - /// for the other. Otherwise, returns `kUnknownNumRegionInvocations` for - /// both successors. - void getNumRegionInvocations(ArrayRef operands, - SmallVectorImpl &countPerRegion); }]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -1198,6 +1198,23 @@ } } +/// Each region of the op is invoked at most once. If the condition is constant, +/// then the number of invocations can be narrowed to either 0 or 1. +void IfOp::getRegionInvocationBounds( + ArrayRef operands, + SmallVectorImpl &invocationBounds) { + if (auto condAttr = operands.front().dyn_cast_or_null()) { + // If the condition is true, `then` is executed once and `else` zero times, + // and vice-versa. + bool cond = condAttr.getValue().isOneValue(); + invocationBounds.assign(1, {0, cond ? 1 : 0}); + invocationBounds.emplace_back(0, cond ? 0 : 1); + } else { + // Non-constant condition: either region may be invoked at most once (bounds {0, 1}). + invocationBounds.assign(2, {0, 1}); + } +} + namespace { // Pattern to remove unused IfOp results. 
struct RemoveUnusedResults : public OpRewritePattern { diff --git a/mlir/test/Dialect/SCF/control-flow-sink.mlir b/mlir/test/Dialect/SCF/control-flow-sink.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SCF/control-flow-sink.mlir @@ -0,0 +1,63 @@ +// RUN: mlir-opt -split-input-file -control-flow-sink %s | FileCheck %s + +// CHECK-LABEL: @test_scf_if_sink +// CHECK: %0 = scf.if %arg0 -> (i32) { +// CHECK: %1 = arith.addi %arg1, %arg1 : i32 +// CHECK: scf.yield %1 : i32 +// CHECK: } else { +// CHECK: %1 = arith.muli %arg1, %arg1 : i32 +// CHECK: scf.yield %1 : i32 +// CHECK: } +// CHECK: return %0 : i32 +func @test_scf_if_sink(%cond : i1, %arg0 : i32) -> i32 { + %0 = arith.addi %arg0, %arg0 : i32 + %1 = arith.muli %arg0, %arg0 : i32 + %result = scf.if %cond -> i32 { + scf.yield %0 : i32 + } else { + scf.yield %1 : i32 + } + return %result : i32 +} + +// ----- + +func private @consume(i32) -> () + +// CHECK-LABEL: @test_scf_if_then_only_sink +// CHECK: scf.if %arg0 { +// CHECK: %0 = arith.addi %arg1, %arg1 : i32 +// CHECK: call @consume(%0) : (i32) -> () +// CHECK: } +// CHECK: return +func @test_scf_if_then_only_sink(%cond : i1, %arg0 : i32) { + %0 = arith.addi %arg0, %arg0 : i32 + scf.if %cond { + call @consume(%0) : (i32) -> () + scf.yield + } + return +} + +// ----- + +func private @consume(i32) -> () + +// CHECK-LABEL: @test_scf_if_double_sink +// CHECK: scf.if %arg0 { +// CHECK: scf.if %arg0 { +// CHECK: %0 = arith.addi %arg1, %arg1 : i32 +// CHECK: call @consume(%0) : (i32) -> () +// CHECK: } +// CHECK: } +// CHECK: return +func @test_scf_if_double_sink(%cond : i1, %arg0 : i32) { + %0 = arith.addi %arg0, %arg0 : i32 + scf.if %cond { + scf.if %cond { + call @consume(%0) : (i32) -> () + scf.yield + } + } + return +}