diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1326,6 +1326,9 @@
   }
 };
 
+/// For each result whose then- and else-yielded values are both defined
+/// outside the if, replace it with an arith.select on the condition (or the
+/// value itself if both branches yield it). Other results stay on a new if.
 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
   using OpRewritePattern<IfOp>::OpRewritePattern;
 
@@ -1334,31 +1337,68 @@
     if (op->getNumResults() == 0)
       return failure();
 
-    if (!llvm::hasSingleElement(op.getThenRegion().front()) ||
-        !llvm::hasSingleElement(op.getElseRegion().front()))
+    auto cond = op.getCondition();
+    auto thenYieldArgs = op.thenYield().getOperands();
+    auto elseYieldArgs = op.elseYield().getOperands();
+
+    SmallVector<Type> nonHoistable;
+    for (const auto &it :
+         llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
+      Value trueVal = std::get<0>(it.value());
+      Value falseVal = std::get<1>(it.value());
+      if (&op.getThenRegion() == trueVal.getParentRegion() ||
+          &op.getElseRegion() == falseVal.getParentRegion())
+        nonHoistable.push_back(trueVal.getType());
+    }
+    // Early exit if there aren't any yielded values we can
+    // hoist outside the if.
+    if (nonHoistable.size() == op->getNumResults())
       return failure();
 
-    auto cond = op.getCondition();
-    auto thenYieldArgs =
-        cast<scf::YieldOp>(op.getThenRegion().front().getTerminator())
-            .getOperands();
-    auto elseYieldArgs =
-        cast<scf::YieldOp>(op.getElseRegion().front().getTerminator())
-            .getOperands();
+    IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond);
+    // The builder may create placeholder blocks; erase them, then move the
+    // original bodies over. Every IR mutation goes through the rewriter so the
+    // pattern driver is notified, as the PatternRewriter contract requires.
+    for (Region *region :
+         {&replacement.getThenRegion(), &replacement.getElseRegion()})
+      while (!region->empty())
+        rewriter.eraseBlock(&region->back());
+    rewriter.inlineRegionBefore(op.getThenRegion(), replacement.getThenRegion(),
+                                replacement.getThenRegion().end());
+    rewriter.inlineRegionBefore(op.getElseRegion(), replacement.getElseRegion(),
+                                replacement.getElseRegion().end());
+
     SmallVector<Value> results(op->getNumResults());
     assert(thenYieldArgs.size() == results.size());
     assert(elseYieldArgs.size() == results.size());
+
+    SmallVector<Value> trueYields;
+    SmallVector<Value> falseYields;
+    // Materialize selects right after the new if (and thus before `op`), so
+    // they dominate every use of the results they replace.
+    rewriter.setInsertionPointAfter(replacement);
     for (const auto &it :
          llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
       Value trueVal = std::get<0>(it.value());
       Value falseVal = std::get<1>(it.value());
-      if (trueVal == falseVal)
+      if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
+          &replacement.getElseRegion() == falseVal.getParentRegion()) {
+        results[it.index()] = replacement.getResult(trueYields.size());
+        trueYields.push_back(trueVal);
+        falseYields.push_back(falseVal);
+      } else if (trueVal == falseVal)
         results[it.index()] = trueVal;
       else
         results[it.index()] = rewriter.create<arith::SelectOp>(
             op.getLoc(), cond, trueVal, falseVal);
     }
 
+    rewriter.setInsertionPointToEnd(replacement.thenBlock());
+    rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
+
+    rewriter.setInsertionPointToEnd(replacement.elseBlock());
+    rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
+
     rewriter.replaceOp(op, results);
     return success();
   }
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -136,26 +136,26 @@
 
 func private @side_effect()
 func @one_unused(%cond: i1) -> (index) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
-  %c3 = arith.constant 3 : index
   %0, %1 = scf.if %cond -> (index, index) {
     call @side_effect() : () -> ()
+    %c0 = "test.value0"() : () -> (index)
+    %c1 = "test.value1"() : () -> (index)
     scf.yield %c0, %c1 : index, index
   } else {
+    %c2 = "test.value2"() : () -> (index)
+    %c3 = "test.value3"() : () -> (index)
     scf.yield %c2, %c3 : index, index
   }
   return %1 : index
 }
 
 // CHECK-LABEL: func @one_unused
-// CHECK-DAG: [[C0:%.*]] = arith.constant 1 : index
-// CHECK-DAG: [[C3:%.*]] = arith.constant 3 : index
 // CHECK: [[V0:%.*]] = scf.if %{{.*}} -> (index) {
 // CHECK: call @side_effect() : () -> ()
-// CHECK: scf.yield [[C0]] : index
+// CHECK: [[C1:%.*]] = "test.value1"
+// CHECK: scf.yield [[C1]] : index
 // CHECK: } else
+// CHECK: [[C3:%.*]] = "test.value3"
 // CHECK: scf.yield [[C3]] : index
 // CHECK: }
 // CHECK: return [[V0]] : index
@@ -164,37 +164,40 @@
 
 func private @side_effect()
 func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
-  %c3 = arith.constant 3 : index
   %0, %1 = scf.if %cond1 -> (index, index) {
     %2, %3 = scf.if %cond2 -> (index, index) {
       call @side_effect() : () -> ()
+      %c0 = "test.value0"() : () -> (index)
+      %c1 = "test.value1"() : () -> (index)
       scf.yield %c0, %c1 : index, index
     } else {
+      %c2 = "test.value2"() : () -> (index)
+      %c3 = "test.value3"() : () -> (index)
       scf.yield %c2, %c3 : index, index
     }
     scf.yield %2, %3 : index, index
   } else {
+    %c0 = "test.value0_2"() : () -> (index)
+    %c1 = "test.value1_2"() : () -> (index)
     scf.yield %c0, %c1 : index, index
   }
   return %1 : index
 }
 
 // CHECK-LABEL: func @nested_unused
-// CHECK-DAG: [[C0:%.*]] = arith.constant 1 : index
-// CHECK-DAG: [[C3:%.*]] = arith.constant 3 : index
 // CHECK: [[V0:%.*]] = scf.if {{.*}} -> (index) {
 // CHECK: [[V1:%.*]] = scf.if {{.*}} -> (index) {
 // CHECK: call @side_effect() : () -> ()
-// CHECK: scf.yield [[C0]] : index
+// CHECK: [[C1:%.*]] = "test.value1"
+// CHECK: scf.yield [[C1]] : index
 // CHECK: } else
+// CHECK: [[C3:%.*]] = "test.value3"
 // CHECK: scf.yield [[C3]] : index
 // CHECK: }
 // CHECK: scf.yield [[V1]] : index
 // CHECK: } else
-// CHECK: scf.yield [[C0]] : index
+// CHECK: [[C1_2:%.*]] = "test.value1_2"
+// CHECK: scf.yield [[C1_2]] : index
 // CHECK: }
 // CHECK: return [[V0]] : index
 
@@ -302,6 +305,27 @@
 // CHECK: [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
 // CHECK: return [[V0]], [[C1]] : index, index
 
+
+func @to_select_with_body(%cond: i1) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = scf.if %cond -> index {
+    "test.op"() : () -> ()
+    scf.yield %c0 : index
+  } else {
+    scf.yield %c1 : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @to_select_with_body
+// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK: scf.if {{.*}} {
+// CHECK: "test.op"() : () -> ()
+// CHECK: }
+// CHECK: [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
+// CHECK: return [[V0]] : index
 // -----
 
 func @to_select2(%cond: i1) -> (index, index) {
@@ -731,38 +755,32 @@
 
 // CHECK-LABEL: @cond_prop
 func @cond_prop(%arg0 : i1) -> index {
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
-  %c3 = arith.constant 3 : index
-  %c4 = arith.constant 4 : index
   %res = scf.if %arg0 -> index {
     %res1 = scf.if %arg0 -> index {
-      %v1 = "test.get_some_value"() : () -> i32
-      scf.yield %c1 : index
+      %v1 = "test.get_some_value1"() : () -> index
+      scf.yield %v1 : index
     } else {
-      %v2 = "test.get_some_value"() : () -> i32
-      scf.yield %c2 : index
+      %v2 = "test.get_some_value2"() : () -> index
+      scf.yield %v2 : index
     }
     scf.yield %res1 : index
   } else {
     %res2 = scf.if %arg0 -> index {
-      %v3 = "test.get_some_value"() : () -> i32
-      scf.yield %c3 : index
+      %v3 = "test.get_some_value3"() : () -> index
+      scf.yield %v3 : index
     } else {
-      %v4 = "test.get_some_value"() : () -> i32
-      scf.yield %c4 : index
+      %v4 = "test.get_some_value4"() : () -> index
+      scf.yield %v4 : index
     }
     scf.yield %res2 : index
   }
   return %res : index
 }
-// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index
 // CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (index) {
-// CHECK-NEXT: %{{.+}} = "test.get_some_value"() : () -> i32
+// CHECK-NEXT: %[[c1:.+]] = "test.get_some_value1"() : () -> index
 // CHECK-NEXT: scf.yield %[[c1]] : index
 // CHECK-NEXT: } else {
-// CHECK-NEXT: %{{.+}} = "test.get_some_value"() : () -> i32
+// CHECK-NEXT: %[[c4:.+]] = "test.get_some_value4"() : () -> index
 // CHECK-NEXT: scf.yield %[[c4]] : index
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[if]] : index
@@ -808,7 +826,6 @@
   return %res#0, %res#1 : i32, i1
 }
 // CHECK-NEXT: %true = arith.constant true
-// CHECK-NEXT: %[[toret:.+]] = arith.xori %arg0, %true : i1
 // CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (i32) {
 // CHECK-NEXT: %[[sv1:.+]] = "test.get_some_value"() : () -> i32
 // CHECK-NEXT: scf.yield %[[sv1]] : i32
@@ -816,6 +833,7 @@
 // CHECK-NEXT: %[[sv2:.+]] = "test.get_some_value"() : () -> i32
 // CHECK-NEXT: scf.yield %[[sv2]] : i32
 // CHECK-NEXT: }
+// CHECK-NEXT: %[[toret:.+]] = arith.xori %arg0, %true : i1
 // CHECK-NEXT: return %[[if]], %[[toret]] : i32, i1
 
 // -----