Index: lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
===================================================================
--- lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -26,6 +26,7 @@
 #include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/MemorySSA.h"
 #include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/Analysis/Utils/Local.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
@@ -47,6 +48,7 @@
 #include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
 #include "llvm/Transforms/Utils/ValueMapper.h"
 #include <algorithm>
@@ -364,7 +366,17 @@
   // some input conditions to the branch.
   bool FullUnswitch = false;
 
-  if (L.isLoopInvariant(BI.getCondition())) {
+  // True when the condition is not loop invariant as an IR value, but
+  // ScalarEvolution proves it invariant and it can be safely re-materialized
+  // in the preheader. SE is optional in this function, so without it this
+  // stays false.
+  bool HoistCondWithSCEV = false;
+  if (SE && !L.isLoopInvariant(BI.getCondition())) {
+    const SCEV *CondSCEV = SE->getSCEV(BI.getCondition());
+    HoistCondWithSCEV =
+        SE->isLoopInvariant(CondSCEV, &L) && isSafeToExpand(CondSCEV, *SE);
+  }
+  if (L.isLoopInvariant(BI.getCondition()) || HoistCondWithSCEV) {
     Invariants.push_back(BI.getCondition());
     FullUnswitch = true;
   } else {
@@ -417,6 +429,27 @@
     }
   });
 
+  if (HoistCondWithSCEV) {
+    // The condition is only invariant according to SCEV. Expand an equivalent
+    // value at the end of the preheader and branch on that instead, so the
+    // rest of the trivial unswitch sees a genuinely loop-invariant condition.
+    auto *OldCond = BI.getCondition();
+    const DataLayout &DL = BI.getModule()->getDataLayout();
+    // TODO: Do we need a SCEVHoister?
+    SCEVExpander Expander(*SE, DL, "");
+    auto *NewCond =
+        Expander.expandCodeFor(SE->getSCEV(OldCond), OldCond->getType(),
+                               L.getLoopPreheader()->getTerminator());
+    BI.setCondition(NewCond);
+    assert(Invariants.size() == 1 && Invariants[0] == OldCond &&
+           "SCEV hoisting is only done when fully unswitching!");
+    Invariants.clear();
+    Invariants.push_back(NewCond);
+
+    // FIXME: This can walk out of our current loop.
+    RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr, MSSAU);
+  }
+
   // If we have scalar evolutions, we need to invalidate them including this
   // loop and the loop containing the exit block.
   if (SE) {
Index: test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll
===================================================================
--- test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll
+++ test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll
@@ -1243,3 +1243,49 @@
 ; CHECK:       loopexit:
 ; CHECK-NEXT:    ret
 }
+
+define i32 @test_iterative_scev_hoist(i32* %var, i32 %cond1.i32,
+                                      i32 %cond2.i32) {
+; CHECK-LABEL: @test_iterative_scev_hoist(
+entry:
+  br label %loop_begin
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    trunc
+; CHECK-NEXT:    br i1 %{{.*}}, label %entry.split, label %loop_exit.split
+;
+; CHECK:       entry.split:
+; CHECK-NEXT:    trunc
+; CHECK-NEXT:    br i1 %{{.*}}, label %entry.split.split, label %loop_exit
+;
+; CHECK:       entry.split.split:
+; CHECK-NEXT:    br label %loop_begin
+
+loop_begin:
+  %cond1 = trunc i32 %cond1.i32 to i1
+  br i1 %cond1, label %continue, label %loop_exit ; first trivial condition
+; CHECK:       loop_begin:
+; CHECK-NEXT:    br label %continue
+
+continue:
+  %var_val = load i32, i32* %var
+  %cond2 = trunc i32 %cond2.i32 to i1
+  br i1 %cond2, label %do_something, label %loop_exit ; second trivial condition
+; CHECK:       continue:
+; CHECK-NEXT:    load
+; CHECK-NEXT:    br label %do_something
+
+do_something:
+  call void @some_func() noreturn nounwind
+  br label %loop_begin
+; CHECK:       do_something:
+; CHECK-NEXT:    call
+; CHECK-NEXT:    br label %loop_begin
+
+loop_exit:
+  ret i32 0
+; CHECK:       loop_exit:
+; CHECK-NEXT:    br label %loop_exit.split
+;
+; CHECK:       loop_exit.split:
+; CHECK-NEXT:    ret
+}