diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1704,27 +1704,79 @@
 
+  /// Merges `if (a) { if (b) { ... } }` into `if (a && b) { ... }`. Each `if`
+  /// may yield results and have an else block, provided those else blocks
+  /// contain only their yield and the merged op yields, in every case, the
+  /// same values as the original nest; otherwise the pattern fails to match.
   LogicalResult matchAndRewrite(IfOp op,
                                 PatternRewriter &rewriter) const override {
-    // Both `if` ops must not yield results and have only `then` block.
-    if (op->getNumResults() != 0 || op.elseBlock())
-      return failure();
-
     auto nestedOps = op.thenBlock()->without_terminator();
     // Nested `if` must be the only op in block.
     if (!llvm::hasSingleElement(nestedOps))
       return failure();
 
+    // If there is an else block, it may only contain its yield.
+    if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
+      return failure();
+
     auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
-    if (!nestedIf || nestedIf->getNumResults() != 0 || nestedIf.elseBlock())
+    if (!nestedIf)
+      return failure();
+
+    // Likewise, the nested else block (if any) may only contain its yield.
+    if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
       return failure();
 
+    SmallVector<Value> thenYield(op.thenYield().getOperands());
+    SmallVector<Value> elseYield;
+    if (op.elseBlock())
+      llvm::append_range(elseYield, op.elseYield().getOperands());
+
+    // The merged `if` takes its else branch whenever either condition is
+    // false, so for the "outer true, inner false" case every yielded value
+    // must match what the outer else branch yields. elseYield has as many
+    // entries as thenYield: both yields must match the results of `op`.
+    for (auto it : llvm::enumerate(thenYield)) {
+      Value thenValue = it.value();
+      Value elseValue = elseYield[it.index()];
+      if (thenValue.getDefiningOp() == nestedIf.getOperation()) {
+        // A result of the nested `if` is only allowed if the nested else
+        // branch yields the same value as the outer else branch; the merged
+        // then branch then yields the nested then-branch value directly.
+        unsigned resultIdx = thenValue.cast<OpResult>().getResultNumber();
+        if (nestedIf.elseYield().getOperand(resultIdx) != elseValue)
+          return failure();
+        thenYield[it.index()] = nestedIf.thenYield().getOperand(resultIdx);
+        continue;
+      }
+      // Any other value is defined above `op` (the nested `if` is the only
+      // op in the then block). If it differs from the else value, merging
+      // would yield the else value when only the inner condition is false.
+      if (thenValue != elseValue)
+        return failure();
+    }
+
     Location loc = op.getLoc();
     Value newCondition = rewriter.create<arith::AndIOp>(
         loc, op.getCondition(), nestedIf.getCondition());
-    auto newIf = rewriter.create<IfOp>(loc, newCondition);
-    Block *newIfBlock = newIf.thenBlock();
-    rewriter.eraseOp(newIfBlock->getTerminator());
+    auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
+    // The IfOp builder may or may not have created the then block, and may
+    // have added an implicit terminator to it. Normalize to one empty block:
+    // the nested then block (including its yield) is merged into it and that
+    // yield is replaced below.
+    Region &newThenRegion = newIf.getThenRegion();
+    Block *newIfBlock = newThenRegion.empty()
+                            ? rewriter.createBlock(&newThenRegion)
+                            : &newThenRegion.front();
+    if (!newIfBlock->empty())
+      rewriter.eraseOp(newIfBlock->getTerminator());
     rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
-    rewriter.eraseOp(op);
+    rewriter.setInsertionPointToEnd(newIfBlock);
+    rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
+    if (!elseYield.empty()) {
+      rewriter.createBlock(&newIf.getElseRegion());
+      rewriter.setInsertionPointToEnd(newIf.elseBlock());
+      rewriter.create<YieldOp>(loc, elseYield);
+    }
+    rewriter.replaceOp(op, newIf.getResults());
     return success();
   }
 };
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -447,6 +447,86 @@
 
 // -----
 
+// CHECK-LABEL: @merge_yielding_nested_if
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8) {
+// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
+// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> f32
+// CHECK: %[[PRE2:.*]] = "test.op2"() : () -> i32
+// CHECK: %[[PRE3:.*]] = "test.op3"() : () -> i8
+// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]]:2 = scf.if %[[COND]] -> (f32, i32)
+// CHECK: %[[IN0:.*]] = "test.inop"() : () -> i32
+// CHECK: %[[IN1:.*]] = "test.inop1"() : () -> f32
+// CHECK: scf.yield %[[IN1]], %[[IN0]] : f32, i32
+// CHECK: } else {
+// CHECK: scf.yield %[[PRE1]], %[[PRE2]] : f32, i32
+// CHECK: }
+// CHECK: return %[[PRE0]], %[[RES]]#0, %[[RES]]#1, %[[PRE3]] : i32, f32, i32, i8
+  %0 = "test.op"() : () -> (i32)
+  %1 = "test.op1"() : () -> (f32)
+  %2 = "test.op2"() : () -> (i32)
+  %3 = "test.op3"() : () -> (i8)
+  %r:4 = scf.if %arg0 -> (i32, f32, i32, i8) {
+    %a:2 = scf.if %arg1 -> (i32, f32) {
+      %i = "test.inop"() : () -> (i32)
+      %i1 = "test.inop1"() : () -> (f32)
+      scf.yield %i, %i1 : i32, f32
+    } else {
+      scf.yield %2, %1 : i32, f32
+    }
+    scf.yield %0, %a#1, %a#0, %3 : i32, f32, i32, i8
+  } else {
+    scf.yield %0, %1, %2, %3 : i32, f32, i32, i8
+  }
+  return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
+}
+
+// CHECK-LABEL: @merge_fail_yielding_nested_if
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_fail_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8) {
+// CHECK-NOT: andi
+  %0 = "test.op"() : () -> (i32)
+  %1 = "test.op1"() : () -> (f32)
+  %2 = "test.op2"() : () -> (i32)
+  %3 = "test.op3"() : () -> (i8)
+  %r:4 = scf.if %arg0 -> (i32, f32, i32, i8) {
+    %a:2 = scf.if %arg1 -> (i32, f32) {
+      %i = "test.inop"() : () -> (i32)
+      %i1 = "test.inop1"() : () -> (f32)
+      scf.yield %i, %i1 : i32, f32
+    } else {
+      scf.yield %0, %1 : i32, f32
+    }
+    scf.yield %0, %a#1, %a#0, %3 : i32, f32, i32, i8
+  } else {
+    scf.yield %0, %1, %2, %3 : i32, f32, i32, i8
+  }
+  return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
+}
+
+// The then branch yields an outer value (%0) that differs from the one the
+// else branch yields (%1). Merging would wrongly produce %1 when %arg0 is
+// true and %arg1 is false, so the nest must be left alone.
+// CHECK-LABEL: @merge_fail_nested_if_outer_value_mismatch
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_fail_nested_if_outer_value_mismatch(%arg0: i1, %arg1: i1) -> i32 {
+// CHECK-NOT: andi
+  %0 = "test.op"() : () -> (i32)
+  %1 = "test.op1"() : () -> (i32)
+  %r = scf.if %arg0 -> (i32) {
+    scf.if %arg1 {
+      "test.inop"() : () -> ()
+    }
+    scf.yield %0 : i32
+  } else {
+    scf.yield %1 : i32
+  }
+  return %r : i32
+}
+
+// -----
+
 // CHECK-LABEL: func @if_condition_swap
 // CHECK-NEXT: %{{.*}} = scf.if %arg0 -> (index) {
 // CHECK-NEXT: %[[i1:.+]] = "test.origFalse"() : () -> index