diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -315,7 +315,9 @@
 }
 
 def IfOp : SCF_Op<"if",
-    [DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getNumRegionInvocations"]>,
+    [DeclareOpInterfaceMethods<RegionBranchOpInterface,
+                               ["getNumRegionInvocations",
+                                "getRegionInvocationBounds"]>,
      SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveSideEffects,
      NoRegionArguments]> {
   let summary = "if-then-else operation";
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1189,6 +1189,25 @@
   return success();
 }
 
+/// Populates `invocationBounds` with one entry per region (then, else). When
+/// the condition operand is a known constant, the untaken region is bounded
+/// to zero invocations; otherwise each region runs at most once.
+void IfOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands,
+    SmallVectorImpl<InvocationBounds> &invocationBounds) {
+  // Per the RegionBranchOpInterface contract, `operands` holds the constant
+  // value of each operand (null if unknown); the condition is operand 0.
+  if (auto cond = operands[0].dyn_cast_or_null<BoolAttr>()) {
+    // The condition is known: one region executes at most once and the other
+    // never executes. The lower bound is kept at 0 conservatively.
+    invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
+    invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
+  } else {
+    // Non-constant condition: each region may execute zero times or once.
+    invocationBounds.assign(2, {0, 1});
+  }
+}
+
 namespace {
 // Pattern to remove unused IfOp results.
 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
diff --git a/mlir/test/Dialect/SCF/control-flow-sink.mlir b/mlir/test/Dialect/SCF/control-flow-sink.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/control-flow-sink.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt -split-input-file -control-flow-sink %s | FileCheck %s
+
+// CHECK-LABEL: @test_scf_if_sink
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32)
+// CHECK: %[[V0:.*]] = scf.if %[[ARG0]]
+// CHECK: %[[V1:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
+// CHECK: scf.yield %[[V1]]
+// CHECK: else
+// CHECK: %[[V1:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+// CHECK: scf.yield %[[V1]]
+// CHECK: return %[[V0]]
+func @test_scf_if_sink(%arg0: i1, %arg1: i32) -> i32 {
+  %0 = arith.addi %arg1, %arg1 : i32
+  %1 = arith.muli %arg1, %arg1 : i32
+  %result = scf.if %arg0 -> i32 {
+    scf.yield %0 : i32
+  } else {
+    scf.yield %1 : i32
+  }
+  return %result : i32
+}
+
+// -----
+
+func private @consume(i32) -> ()
+
+// CHECK-LABEL: @test_scf_if_then_only_sink
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32)
+// CHECK: scf.if %[[ARG0]]
+// CHECK: %[[V0:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
+// CHECK: call @consume(%[[V0]])
+func @test_scf_if_then_only_sink(%arg0: i1, %arg1: i32) {
+  %0 = arith.addi %arg1, %arg1 : i32
+  scf.if %arg0 {
+    call @consume(%0) : (i32) -> ()
+    scf.yield
+  }
+  return
+}
+
+// -----
+
+func private @consume(i32) -> ()
+
+// CHECK-LABEL: @test_scf_if_double_sink
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32)
+// CHECK: scf.if %[[ARG0]]
+// CHECK: scf.if %[[ARG0]]
+// CHECK: %[[V0:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
+// CHECK: call @consume(%[[V0]])
+func @test_scf_if_double_sink(%arg0: i1, %arg1: i32) {
+  %0 = arith.addi %arg1, %arg1 : i32
+  scf.if %arg0 {
+    scf.if %arg0 {
+      call @consume(%0) : (i32) -> ()
+      scf.yield
+    }
+  }
+  return
+}