[mlir][SCF] Implement RegionBranchOpInterface::getRegionInvocationBounds for scf.if

scf.if now reports per-region invocation bounds: with a constant condition
the taken region runs at most once and the other region never runs;
otherwise each region runs 0 or 1 times. The new -control-flow-sink test
exercises sinking values that are only used inside one scf.if branch.

The interface documentation also states that getRegionInvocationBounds may
be called speculatively with operands that differ from the op's current
operands.

diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -315,7 +315,9 @@
 }
 
 def IfOp : SCF_Op<"if",
-    [DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getNumRegionInvocations"]>,
+    [DeclareOpInterfaceMethods<RegionBranchOpInterface,
+                               ["getNumRegionInvocations",
+                                "getRegionInvocationBounds"]>,
      SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveSideEffects,
      NoRegionArguments]> {
   let summary = "if-then-else operation";
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -140,9 +140,14 @@
         of invocations cannot be statically determined, then it will not have a
         value (i.e., it is set to `llvm::None`).
 
-        `operands` is a set of optional attributes that either correspond to a
-        constant values for each operand of this operation, or null if that
+        `operands` is a set of optional attributes that either correspond to
+        constant values for each operand of this operation or null if that
         operand is not a constant.
+
+        This method may be called speculatively on operations where the provided
+        operands are not necessarily the same as the operation's current
+        operands. This may occur in analyses that wish to determine "what would
+        be the region invocations if these were the operands?"
       }],
       "void", "getRegionInvocationBounds",
       (ins "::mlir::ArrayRef<::mlir::Attribute>":$operands,
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1188,6 +1188,20 @@
   return success();
 }
 
+void IfOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands,
+    SmallVectorImpl<InvocationBounds> &invocationBounds) {
+  if (auto cond = operands[0].dyn_cast_or_null<BoolAttr>()) {
+    // Known condition: the selected region runs at most once (its lower bound
+    // is conservatively left at 0) and the other region never runs.
+    invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
+    invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
+  } else {
+    // Non-constant condition. Each region may be executed 0 or 1 times.
+    invocationBounds.assign(2, {0, 1});
+  }
+}
+
 namespace {
 // Pattern to remove unused IfOp results.
 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
diff --git a/mlir/test/Dialect/SCF/control-flow-sink.mlir b/mlir/test/Dialect/SCF/control-flow-sink.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/control-flow-sink.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt -split-input-file -control-flow-sink %s | FileCheck %s
+
+// CHECK-LABEL: @test_scf_if_sink
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32)
+// CHECK: %[[V0:.*]] = scf.if %[[ARG0]]
+// CHECK:   %[[V1:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
+// CHECK:   scf.yield %[[V1]]
+// CHECK: else
+// CHECK:   %[[V1:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+// CHECK:   scf.yield %[[V1]]
+// CHECK: return %[[V0]]
+func @test_scf_if_sink(%arg0: i1, %arg1: i32) -> i32 {
+  %0 = arith.addi %arg1, %arg1 : i32
+  %1 = arith.muli %arg1, %arg1 : i32
+  %result = scf.if %arg0 -> i32 {
+    scf.yield %0 : i32
+  } else {
+    scf.yield %1 : i32
+  }
+  return %result : i32
+}
+
+// -----
+
+func private @consume(i32) -> ()
+
+// CHECK-LABEL: @test_scf_if_then_only_sink
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32)
+// CHECK: scf.if %[[ARG0]]
+// CHECK:   %[[V0:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
+// CHECK:   call @consume(%[[V0]])
+func @test_scf_if_then_only_sink(%arg0: i1, %arg1: i32) {
+  %0 = arith.addi %arg1, %arg1 : i32
+  scf.if %arg0 {
+    call @consume(%0) : (i32) -> ()
+    scf.yield
+  }
+  return
+}
+
+// -----
+
+func private @consume(i32) -> ()
+
+// CHECK-LABEL: @test_scf_if_double_sink
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32)
+// CHECK: scf.if %[[ARG0]]
+// CHECK:   scf.if %[[ARG0]]
+// CHECK:     %[[V0:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
+// CHECK:     call @consume(%[[V0]])
+func @test_scf_if_double_sink(%arg0: i1, %arg1: i32) {
+  %0 = arith.addi %arg1, %arg1 : i32
+  scf.if %arg0 {
+    scf.if %arg0 {
+      call @consume(%0) : (i32) -> ()
+      scf.yield
+    }
+  }
+  return
+}