diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1768,6 +1768,10 @@
     if (op.elseBlock())
       llvm::append_range(elseYield, op.elseYield().getOperands());
 
+    // A list of indices for which we should upgrade the value yielded
+    // in the else to a select.
+    SmallVector<unsigned> elseYieldsToUpgradeToSelect;
+
     // If the outer scf.if yields a value produced by the inner scf.if,
     // only permit combining if the value yielded when the condition
     // is false in the outer scf.if is the same value yielded when the
@@ -1785,6 +1789,20 @@
         // If the correctness test passes, we will yield
         // corresponding value from the inner scf.if
         thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
+        continue;
       }
+
+      // Otherwise, we need to ensure the else block of the combined
+      // condition still returns the same value when the outer condition is
+      // true and the inner condition is false. This can be accomplished if
+      // the then value is defined outside the outer scf.if and we replace the
+      // value with a select that considers just the outer condition. Since
+      // the else region contains just the yield, its yielded value is
+      // defined outside the scf.if, by definition.
+
+      // If the then value is defined within the scf.if, bail.
+      if (tup.value().getParentRegion() == &op.getThenRegion())
+        return failure();
+      elseYieldsToUpgradeToSelect.push_back(tup.index());
     }
 
@@ -1792,6 +1810,21 @@
     Value newCondition = rewriter.create<arith::AndIOp>(
         loc, op.getCondition(), nestedIf.getCondition());
     auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
+
+    // Collect the values that replace the original scf.if's results. For each
+    // index recorded above, the combined scf.if would yield the else value when
+    // only the inner condition is false, so that result instead becomes a
+    // select on the outer condition alone. Both select operands are defined
+    // outside the original scf.if (checked above for the then value; the else
+    // block only yields), so the selects are created just before the new one.
+    SmallVector<Value> results;
+    llvm::append_range(results, newIf.getResults());
+    rewriter.setInsertionPoint(newIf);
+
+    for (unsigned idx : elseYieldsToUpgradeToSelect)
+      results[idx] = rewriter.create<arith::SelectOp>(
+          loc, op.getCondition(), thenYield[idx], elseYield[idx]);
+
     Block *newIfBlock = newIf.thenBlock();
     if (newIfBlock)
       rewriter.eraseOp(newIfBlock->getTerminator());
@@ -1805,7 +1838,7 @@
       rewriter.setInsertionPointToEnd(newIf.elseBlock());
       rewriter.create<YieldOp>(loc, elseYield);
     }
-    rewriter.replaceOp(op, newIf.getResults());
+    rewriter.replaceOp(op, results);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -556,7 +556,7 @@
 // CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
 // CHECK: %[[PRE1:.*]] = "test.op1"() : () -> i32
 // CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
-// CHECK: %[[RES:.*]] = arith.select %[[COND]], %[[PRE0]], %[[PRE1]]
+// CHECK: %[[RES:.*]] = arith.select %[[ARG0]], %[[PRE0]], %[[PRE1]]
 // CHECK: scf.if %[[COND]]
 // CHECK: "test.run"() : () -> ()
 // CHECK: }
@@ -596,6 +596,7 @@
   }
   return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
 }
+
 // -----
 
 // CHECK-LABEL: func @if_condition_swap