diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1746,27 +1746,106 @@
 
   LogicalResult matchAndRewrite(IfOp op,
                                 PatternRewriter &rewriter) const override {
-    // Both `if` ops must not yield results and have only `then` block.
-    if (op->getNumResults() != 0 || op.elseBlock())
-      return failure();
-
+    // Combines `if (a) { if (b) { body } }` into `if (a && b) { body }`.
+    // Both ops may yield results as long as their `else` blocks (if any)
+    // contain nothing but the `scf.yield`; the loop below maps every result
+    // of the outer op onto the combined op.
+
     auto nestedOps = op.thenBlock()->without_terminator();
     // Nested `if` must be the only op in block.
     if (!llvm::hasSingleElement(nestedOps))
       return failure();
 
+    // If there is an else block, it can only yield.
+    if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
+      return failure();
+
     auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
-    if (!nestedIf || nestedIf->getNumResults() != 0 || nestedIf.elseBlock())
+    if (!nestedIf)
+      return failure();
+
+    // Same restriction for the nested `else` block.
+    if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
       return failure();
 
+    SmallVector<Value> thenYield(op.thenYield().getOperands());
+    SmallVector<Value> elseYield;
+    if (op.elseBlock())
+      llvm::append_range(elseYield, op.elseYield().getOperands());
+
+    // Indices of results that the outer `then` branch yields without going
+    // through the nested `if`. The combined `else` branch also runs when the
+    // outer condition holds but the inner one does not, and on that path the
+    // original op yielded the `then` value, so these results are recomputed
+    // as `select(outer condition, then value, else value)`.
+    SmallVector<unsigned> elseYieldsToUpgradeToSelect;
+
+    // If the outer scf.if yields a value produced by the inner scf.if,
+    // only permit combining if the value yielded when the condition
+    // is false in the outer scf.if is the same value yielded when the
+    // inner scf.if condition is false.
+    // Note that the array access to elseYield will not go out of bounds
+    // since it must have the same length as thenYield, since they both
+    // come from the same scf.if.
+    for (const auto &tup : llvm::enumerate(thenYield)) {
+      if (tup.value().getDefiningOp() == nestedIf) {
+        auto nestedIdx = tup.value().cast<OpResult>().getResultNumber();
+        if (nestedIf.elseYield().getOperand(nestedIdx) !=
+            elseYield[tup.index()]) {
+          return failure();
+        }
+        // If the correctness test passes, we will yield
+        // corresponding value from the inner scf.if
+        thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
+        continue;
+      }
+
+      // Both branches yield the same value, so the combined op already
+      // produces it on every path.
+      if (tup.value() == elseYield[tup.index()])
+        continue;
+
+      // The select is created in front of the combined op, so the `then`
+      // value must be defined outside the outer op. The `else` value always
+      // is, since that block contains only its yield.
+      if (tup.value().getParentRegion() == &op.getThenRegion())
+        return failure();
+      elseYieldsToUpgradeToSelect.push_back(tup.index());
+    }
+
     Location loc = op.getLoc();
     Value newCondition = rewriter.create<arith::AndIOp>(
         loc, op.getCondition(), nestedIf.getCondition());
-    auto newIf = rewriter.create<IfOp>(loc, newCondition);
-    Block *newIfBlock = newIf.thenBlock();
-    rewriter.eraseOp(newIfBlock->getTerminator());
+    auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
+
+    // The outer op's results map onto the combined op's results, except for
+    // the indices collected above: those get a select placed right before
+    // `newIf`, where both of its value operands are available.
+    SmallVector<Value> results;
+    llvm::append_range(results, newIf.getResults());
+    rewriter.setInsertionPoint(newIf);
+    for (unsigned idx : elseYieldsToUpgradeToSelect)
+      results[idx] = rewriter.create<arith::SelectOp>(
+          loc, op.getCondition(), thenYield[idx], elseYield[idx]);
+
+    // Normalize the `then` region to a single terminator-free block before
+    // splicing in the nested body, without relying on whether the IfOp
+    // builder created that block or an implicit terminator in it.
+    Region &newThenRegion = newIf.getThenRegion();
+    if (newThenRegion.empty())
+      rewriter.createBlock(&newThenRegion);
+    Block *newIfBlock = &newThenRegion.back();
+    if (!newIfBlock->empty())
+      rewriter.eraseOp(newIfBlock->getTerminator());
     rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
-    rewriter.eraseOp(op);
+    rewriter.setInsertionPointToEnd(newIf.thenBlock());
+    rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
+    if (!elseYield.empty()) {
+      rewriter.createBlock(&newIf.getElseRegion());
+      rewriter.setInsertionPointToEnd(newIf.elseBlock());
+      rewriter.create<YieldOp>(loc, elseYield);
+    }
+    rewriter.replaceOp(op, results);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -491,6 +491,115 @@
 
 // -----
 
+// CHECK-LABEL: @merge_yielding_nested_if
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8) {
+// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
+// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> f32
+// CHECK: %[[PRE2:.*]] = "test.op2"() : () -> i32
+// CHECK: %[[PRE3:.*]] = "test.op3"() : () -> i8
+// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]]:2 = scf.if %[[COND]] -> (f32, i32)
+// CHECK:   %[[IN0:.*]] = "test.inop"() : () -> i32
+// CHECK:   %[[IN1:.*]] = "test.inop1"() : () -> f32
+// CHECK:   scf.yield %[[IN1]], %[[IN0]] : f32, i32
+// CHECK: } else {
+// CHECK:   scf.yield %[[PRE1]], %[[PRE2]] : f32, i32
+// CHECK: }
+// CHECK: return %[[PRE0]], %[[RES]]#0, %[[RES]]#1, %[[PRE3]] : i32, f32, i32, i8
+  %0 = "test.op"() : () -> (i32)
+  %1 = "test.op1"() : () -> (f32)
+  %2 = "test.op2"() : () -> (i32)
+  %3 = "test.op3"() : () -> (i8)
+  %r:4 = scf.if %arg0 -> (i32, f32, i32, i8) {
+    %a:2 = scf.if %arg1 -> (i32, f32) {
+      %i = "test.inop"() : () -> (i32)
+      %i1 = "test.inop1"() : () -> (f32)
+      scf.yield %i, %i1 : i32, f32
+    } else {
+      scf.yield %2, %1 : i32, f32
+    }
+    scf.yield %0, %a#1, %a#0, %3 : i32, f32, i32, i8
+  } else {
+    scf.yield %0, %1, %2, %3 : i32, f32, i32, i8
+  }
+  return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
+}
+
+// CHECK-LABEL: @merge_yielding_nested_if_nv1
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_yielding_nested_if_nv1(%arg0: i1, %arg1: i1) {
+// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
+// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> f32
+// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
+// CHECK: scf.if %[[COND]]
+// CHECK:   %[[IN0:.*]] = "test.inop"() : () -> i32
+// CHECK:   %[[IN1:.*]] = "test.inop1"() : () -> f32
+// CHECK: }
+  %0 = "test.op"() : () -> (i32)
+  %1 = "test.op1"() : () -> (f32)
+  scf.if %arg0 {
+    %a:2 = scf.if %arg1 -> (i32, f32) {
+      %i = "test.inop"() : () -> (i32)
+      %i1 = "test.inop1"() : () -> (f32)
+      scf.yield %i, %i1 : i32, f32
+    } else {
+      scf.yield %0, %1 : i32, f32
+    }
+  }
+  return
+}
+
+// When %arg0 is true and %arg1 is false the original op yields %0, so the
+// select replacing its result must test %arg0, not the combined condition.
+// CHECK-LABEL: @merge_yielding_nested_if_nv2
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
+// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
+// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> i32
+// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]] = arith.select %[[ARG0]], %[[PRE0]], %[[PRE1]]
+// CHECK: scf.if %[[COND]]
+// CHECK:   "test.run"() : () -> ()
+// CHECK: }
+// CHECK: return %[[RES]]
+  %0 = "test.op"() : () -> (i32)
+  %1 = "test.op1"() : () -> (i32)
+  %r = scf.if %arg0 -> i32 {
+    scf.if %arg1 {
+      "test.run"() : () -> ()
+    }
+    scf.yield %0 : i32
+  } else {
+    scf.yield %1 : i32
+  }
+  return %r : i32
+}
+
+// CHECK-LABEL: @merge_fail_yielding_nested_if
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_fail_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8) {
+// CHECK-NOT: andi
+  %0 = "test.op"() : () -> (i32)
+  %1 = "test.op1"() : () -> (f32)
+  %2 = "test.op2"() : () -> (i32)
+  %3 = "test.op3"() : () -> (i8)
+  %r:4 = scf.if %arg0 -> (i32, f32, i32, i8) {
+    %a:2 = scf.if %arg1 -> (i32, f32) {
+      %i = "test.inop"() : () -> (i32)
+      %i1 = "test.inop1"() : () -> (f32)
+      scf.yield %i, %i1 : i32, f32
+    } else {
+      scf.yield %0, %1 : i32, f32
+    }
+    scf.yield %0, %a#1, %a#0, %3 : i32, f32, i32, i8
+  } else {
+    scf.yield %0, %1, %2, %3 : i32, f32, i32, i8
+  }
+  return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
+}
+// -----
+
 // CHECK-LABEL: func @if_condition_swap
 // CHECK-NEXT: %{{.*}} = scf.if %arg0 -> (index) {
 // CHECK-NEXT: %[[i1:.+]] = "test.origFalse"() : () -> index