diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1513,17 +1513,36 @@
     if (nextIf.getCondition() != prevIf.getCondition())
       return failure();
 
-    // Don't permit merging if a result of the first if is used
-    // within the second.
-    if (llvm::any_of(prevIf->getUsers(),
-                     [&](Operation *user) { return nextIf->isAncestor(user); }))
-      return failure();
+    // Rewire uses of prevIf's results nested inside nextIf to the value that
+    // prevIf yields on the matching branch. The combined scf.if places
+    // prevIf's ops ahead of nextIf's, so the yielded values dominate those
+    // uses; this is what allows merging when nextIf consumes prevIf's results.
+    // Uses outside nextIf are left untouched. The use lists are walked with
+    // make_early_inc_range because use.set() unlinks the current use.
+    SmallVector<Value> prevElseYielded;
+    if (!prevIf.getElseRegion().empty())
+      prevElseYielded = prevIf.elseYield().getOperands();
+    for (auto it : llvm::zip(prevIf.getResults(),
+                             prevIf.thenYield().getOperands(), prevElseYielded))
+      for (OpOperand &use :
+           llvm::make_early_inc_range(std::get<0>(it).getUses())) {
+        Region *useRegion = use.getOwner()->getParentRegion();
+        Value replacement;
+        if (nextIf.getThenRegion().isAncestor(useRegion))
+          replacement = std::get<1>(it);
+        else if (nextIf.getElseRegion().isAncestor(useRegion))
+          replacement = std::get<2>(it);
+        else
+          continue;
+        rewriter.updateRootInPlace(use.getOwner(),
+                                   [&] { use.set(replacement); });
+      }
 
     SmallVector<Type> mergedTypes(prevIf.getResultTypes());
     llvm::append_range(mergedTypes, nextIf.getResultTypes());
 
     IfOp combinedIf = rewriter.create<IfOp>(
-        nextIf.getLoc(), mergedTypes, nextIf.getCondition(), /*hasElse=*/false);
+        nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
     rewriter.eraseBlock(&combinedIf.getThenRegion().back());
 
     YieldOp thenYield = prevIf.thenYield();
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1119,6 +1119,35 @@
 // CHECK-NEXT:       "test.secondCodeTrue"() : () -> ()
 // CHECK-NEXT:     }
 
+// CHECK-LABEL: @combineIfsUsed
+func @combineIfsUsed(%arg0 : i1, %arg2: i64) -> (i32, i32) {
+  %res = scf.if %arg0 -> i32 {
+    %v = "test.firstCodeTrue"() : () -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.firstCodeFalse"() : () -> i32
+    scf.yield %v2 : i32
+  }
+  %res2 = scf.if %arg0 -> i32 {
+    %v = "test.secondCodeTrue"(%res) : (i32) -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.secondCodeFalse"(%res) : (i32) -> i32
+    scf.yield %v2 : i32
+  }
+  return %res, %res2 : i32, i32
+}
+// CHECK-NEXT:     %[[res:.+]]:2 = scf.if %arg0 -> (i32, i32) {
+// CHECK-NEXT:       %[[tval0:.+]] = "test.firstCodeTrue"() : () -> i32
+// CHECK-NEXT:       %[[tval:.+]] = "test.secondCodeTrue"(%[[tval0]]) : (i32) -> i32
+// CHECK-NEXT:       scf.yield %[[tval0]], %[[tval]] : i32, i32
+// CHECK-NEXT:     } else {
+// CHECK-NEXT:       %[[fval0:.+]] = "test.firstCodeFalse"() : () -> i32
+// CHECK-NEXT:       %[[fval:.+]] = "test.secondCodeFalse"(%[[fval0]]) : (i32) -> i32
+// CHECK-NEXT:       scf.yield %[[fval0]], %[[fval]] : i32, i32
+// CHECK-NEXT:     }
+// CHECK-NEXT:     return %[[res]]#0, %[[res]]#1 : i32, i32
+
 // -----
 
 // CHECK-LABEL: func @propagate_into_execute_region