diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp --- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -28,6 +28,7 @@ #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/MustExecute.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -109,6 +110,12 @@ cl::desc("Max number of memory uses to explore during " "partial unswitching analysis"), cl::init(100), cl::Hidden); +static cl::opt FreezeLoopUnswitchCond( + "freeze-loop-unswitch-cond", + cl::init(false), cl::Hidden, + cl::desc("If enabled, the freeze instruction will be added to condition " + "of loop unswitch to prevent miscompilation.")); + /// Collect all of the loop invariant input values transitively used by the /// homogeneous instruction graph from a given root. @@ -200,11 +207,14 @@ ArrayRef Invariants, bool Direction, BasicBlock &UnswitchedSucc, - BasicBlock &NormalSucc) { + BasicBlock &NormalSucc, + bool InsertFreeze) { IRBuilder<> IRB(&BB); Value *Cond = Direction ? IRB.CreateOr(Invariants) : IRB.CreateAnd(Invariants); + if (InsertFreeze) + Cond = IRB.CreateFreeze(Cond, Cond->getName() + ".fr"); IRB.CreateCondBr(Cond, Direction ? &UnswitchedSucc : &NormalSucc, Direction ? &NormalSucc : &UnswitchedSucc); } @@ -565,7 +575,7 @@ "Must have an `and` of `i1`s or `select i1 X, Y, false`s for the" " condition!"); buildPartialUnswitchConditionalBranch(*OldPH, Invariants, ExitDirection, - *UnswitchedBB, *NewPH); + *UnswitchedBB, *NewPH, false); } // Update the dominator tree with the added edge. @@ -2124,6 +2134,13 @@ SE->forgetTopmostLoop(&L); } + bool InsertFreeze = false; + if (FreezeLoopUnswitchCond) { + ICFLoopSafetyInfo SafetyInfo; + SafetyInfo.computeLoopSafetyInfo(&L); + InsertFreeze = !SafetyInfo.isGuaranteedToExecute(TI, &DT, &L); + } + // If the edge from this terminator to a successor dominates that successor, // store a map from each block in its dominator subtree to it. This lets us // tell when cloning for a particular successor if a block is dominated by @@ -2198,6 +2215,28 @@ BasicBlock *ClonedPH = ClonedPHs.begin()->second; BI->setSuccessor(ClonedSucc, ClonedPH); BI->setSuccessor(1 - ClonedSucc, LoopPH); + if (InsertFreeze) { + auto Cond = BI->getCondition(); + if (!isGuaranteedNotToBeUndefOrPoison(Cond, &AC, BI, &DT)) { + auto FrozenCond = new FreezeInst(Cond, Cond->getName() + ".fr"); + if (dyn_cast(Cond)) { + if (PHINode *PN = dyn_cast(Cond)) + FrozenCond->insertAfter(PN->getParent()->getFirstNonPHI()); + else if (InvokeInst *II = dyn_cast(Cond)) { + auto *DestBB = dyn_cast(II->getOperand(1)); + FrozenCond->insertAfter(DestBB->getFirstNonPHI()); + } else + FrozenCond->insertAfter(dyn_cast(Cond)); + + Cond->replaceUsesWithIf(FrozenCond, [](Use &U) { + return !isa(U.getUser()); + }); + } else { + FrozenCond->insertBefore(BI); + BI->setCondition(FrozenCond); + } + } + } DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH}); } else { assert(SI && "Must either be a branch or switch!"); @@ -2212,6 +2251,28 @@ else Case.setSuccessor(ClonedPHs.find(Case.getCaseSuccessor())->second); + if (InsertFreeze) { + auto Cond = SI->getCondition(); + if (!isGuaranteedNotToBeUndefOrPoison(Cond, &AC, SI, &DT)) { + auto FrozenCond = new FreezeInst(Cond, Cond->getName() + ".fr"); + if (dyn_cast(Cond)) { + if (PHINode *PN = dyn_cast(Cond)) + FrozenCond->insertAfter(PN->getParent()->getFirstNonPHI()); + else if (InvokeInst *II = dyn_cast(Cond)) { + auto *DestBB = dyn_cast(II->getOperand(1)); + FrozenCond->insertAfter(DestBB->getFirstNonPHI()); + } else + FrozenCond->insertAfter(dyn_cast(Cond)); + + Cond->replaceUsesWithIf(FrozenCond, [](Use &U) { + return !isa(U.getUser()); + }); + } else { + FrozenCond->insertBefore(SI); + SI->setCondition(FrozenCond); + } + } + } // We need to use the set to populate domtree updates as even when there // are multiple cases pointing at the same successor we only want to // remove and insert one edge in the domtree. @@ -2292,7 +2353,7 @@ *SplitBB, Invariants, Direction, *ClonedPH, *LoopPH, L, MSSAU); else buildPartialUnswitchConditionalBranch(*SplitBB, Invariants, Direction, - *ClonedPH, *LoopPH); + *ClonedPH, *LoopPH, InsertFreeze); DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH}); if (MSSAU) { @@ -2376,17 +2437,28 @@ "Should not be replacing constant values!"); // Use make_early_inc_range here as set invalidates the iterator. for (Use &U : llvm::make_early_inc_range(Invariant->uses())) { - Instruction *UserI = dyn_cast(U.getUser()); - if (!UserI) - continue; - - // Replace it with the 'continue' side if in the main loop body, and the - // unswitched if in the cloned blocks. - if (DT.dominates(LoopPH, UserI->getParent())) - U.set(ContinueReplacement); - else if (ReplaceUnswitched && - DT.dominates(ClonedPH, UserI->getParent())) - U.set(UnswitchedReplacement); + auto ReplaceIfDominated = [&](Use &U) { + Instruction *UserI = dyn_cast(U.getUser()); + if (!UserI) + return; + + // Replace it with the 'continue' side if in the main loop body, and + // the unswitched if in the cloned blocks. + if (DT.dominates(LoopPH, UserI->getParent())) + U.set(ContinueReplacement); + else if (ReplaceUnswitched && + DT.dominates(ClonedPH, UserI->getParent())) + U.set(UnswitchedReplacement); + }; + + ReplaceIfDominated(U); + + // If V is invariant, Freeze(V) is also invariant. As we try to replace + // the use of V to constant, we need to try replace the use of Freeze(V) + // to constant. + if (auto *FI = dyn_cast(U.getUser())) + for (Use &UU : llvm::make_early_inc_range(FI->uses())) + ReplaceIfDominated(UU); } } } diff --git a/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch.ll b/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch.ll --- a/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch.ll +++ b/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch.ll @@ -9,6 +9,8 @@ declare void @sink1(i32) declare void @sink2(i32) +declare void @sink3(i1) +declare void @sink4(i1) declare i1 @cond() declare i32 @cond.i32() @@ -231,6 +233,117 @@ ; CHECK-NEXT: ret } +define i32 @test1_freeze(i1* %ptr0, i1* %ptr1, i1* %ptr2) { +; CHECK-LABEL: @test1_freeze( +entry: + %cond1 = load i1, i1* %ptr1 + %cond2 = load i1, i1* %ptr2 + br label %loop_begin +; CHECK-NEXT: entry: +; CHECK-NEXT: %cond1 = load i1, i1* %ptr1, align 1 +; CHECK-NEXT: %cond2 = load i1, i1* %ptr2, align 1 +; CHECK-NEXT: br i1 %cond1, label %entry.split.us, label %entry.split + +loop_begin: + br i1 %cond1, label %loop_a, label %loop_b + +loop_a: + call i32 @a() + br label %latch +; The 'loop_a' unswitched loop. +; +; CHECK: entry.split.us: +; CHECK-NEXT: br label %loop_begin.us +; +; CHECK: loop_begin.us: +; CHECK-NEXT: br label %loop_a.us +; +; CHECK: loop_a.us: +; CHECK-NEXT: %0 = call i32 @a() +; CHECK-NEXT: br label %latch.us +; +; CHECK: latch.us: +; CHECK-NEXT: %[[V:.*]] = load i1, i1* %ptr +; CHECK-NEXT: br i1 %[[V]], label %loop_begin.us, label %loop_exit.split.us +; +; CHECK: loop_exit.split.us: +; CHECK-NEXT: br label %loop_exit + +loop_b: + call i32 @b() + br i1 %cond2, label %loop_b_a, label %loop_b_b +; The second unswitched condition. +; +; CHECK: entry.split: +; CHECK-NEXT: br i1 %cond2, label %entry.split.split.us, label %entry.split.split + +loop_b_a: + call void @sink3(i1 %cond2) + br label %latch +; The 'loop_b_a' unswitched loop. +; %cond2 is replaced to true +; +; CHECK: entry.split.split.us: +; CHECK-NEXT: br label %loop_begin.us1 +; +; CHECK: loop_begin.us1: +; CHECK-NEXT: br label %loop_b.us +; +; CHECK: loop_b.us: +; CHECK-NEXT: %1 = call i32 @b() +; CHECK-NEXT: br label %loop_b_a.us +; +; CHECK: loop_b_a.us: +; CHECK-NEXT: call void @sink3(i1 true) +; CHECK-NEXT: br label %latch.us2 +; +; CHECK: latch.us2: +; CHECK-NEXT: %[[V:.*]] = load i1, i1* %ptr +; CHECK-NEXT: br i1 %[[V]], label %loop_begin.us1, label %loop_exit.split.split.us +; +; CHECK: loop_exit.split.split.us: +; CHECK-NEXT: br label %loop_exit.split + +loop_b_b: + call void @sink4(i1 %cond2) + br label %latch +; The 'loop_b_b' unswitched loop. +; %cond2 is replaced to false +; +; CHECK: entry.split.split: +; CHECK-NEXT: br label %loop_begin +; +; CHECK: loop_begin: +; CHECK-NEXT: br label %loop_b +; +; CHECK: loop_b: +; CHECK-NEXT: %2 = call i32 @b() +; CHECK-NEXT: br label %loop_b_b +; +; CHECK: loop_b_b: +; CHECK-NEXT: call void @sink4(i1 false) +; CHECK-NEXT: br label %latch +; +; CHECK: latch: +; CHECK-NEXT: %[[V:.*]] = load i1, i1* %ptr +; CHECK-NEXT: br i1 %[[V]], label %loop_begin, label %loop_exit.split.split +; +; CHECK: loop_exit.split.split: +; CHECK-NEXT: br label %loop_exit.split + +latch: + %v = load i1, i1* %ptr0 + br i1 %v, label %loop_begin, label %loop_exit + +loop_exit: + ret i32 0 +; CHECK: loop_exit.split: +; CHECK-NEXT: br label %loop_exit +; +; CHECK: loop_exit: +; CHECK-NEXT: ret +} + define i32 @test2(i1* %ptr, i1 %cond1, i32* %a.ptr, i32* %b.ptr, i32* %c.ptr) { ; CHECK-LABEL: @test2( entry: