diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp --- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -116,13 +116,23 @@ cl::desc("If enabled, the freeze instruction will be added to condition " "of loop unswitch to prevent miscompilation.")); +// Helper to skip (select x, true, false), which matches both a logical AND and +// OR and can confuse code that tries to determine if \p Cond is either a +// logical AND or OR but not both. +static Value *skipTrivialSelect(Value *Cond) { + Value *CondNext; + while (match(Cond, m_Select(m_Value(CondNext), m_One(), m_Zero()))) + Cond = CondNext; + return Cond; +} + /// Collect all of the loop invariant input values transitively used by the /// homogeneous instruction graph from a given root. /// /// This essentially walks from a root recursively through loop variant operands -/// which have the exact same opcode and finds all inputs which are loop -/// invariant. For some operations these can be re-associated and unswitched out -/// of the loop entirely. +/// which perform the same logical operation (AND or OR) and finds all +/// inputs which are loop invariant. For some operations these can be +/// re-associated and unswitched out of the loop entirely. static TinyPtrVector<Value *> collectHomogenousInstGraphLoopInvariants(Loop &L, Instruction &Root, LoopInfo &LI) { @@ -152,7 +162,7 @@ } // If not an instruction with the same opcode, nothing we can do. 
- Instruction *OpI = dyn_cast<Instruction>(OpV); + Instruction *OpI = dyn_cast<Instruction>(skipTrivialSelect(OpV)); if (OpI && ((IsRootAnd && match(OpI, m_LogicalAnd())) || (IsRootOr && match(OpI, m_LogicalOr())))) { @@ -454,7 +464,8 @@ Invariants.push_back(BI.getCondition()); FullUnswitch = true; } else { - if (auto *CondInst = dyn_cast<Instruction>(BI.getCondition())) + if (auto *CondInst = + dyn_cast<Instruction>(skipTrivialSelect(BI.getCondition()))) Invariants = collectHomogenousInstGraphLoopInvariants(L, *CondInst, LI); if (Invariants.empty()) { LLVM_DEBUG(dbgs() << " Couldn't find invariant inputs!\n"); @@ -488,8 +499,9 @@ // is a graph of `or` operations, or the exit block is along the false edge // and the condition is a graph of `and` operations. if (!FullUnswitch) { - if (ExitDirection ? !match(BI.getCondition(), m_LogicalOr()) - : !match(BI.getCondition(), m_LogicalAnd())) { + Value *Cond = skipTrivialSelect(BI.getCondition()); + if (ExitDirection ? !match(Cond, m_LogicalOr()) + : !match(Cond, m_LogicalAnd())) { LLVM_DEBUG(dbgs() << " Branch condition is in improper form for " "non-full unswitch!\n"); return false; @@ -569,11 +581,11 @@ // Only unswitching a subset of inputs to the condition, so we will need to // build a new branch that merges the invariant inputs. 
if (ExitDirection) - assert(match(BI.getCondition(), m_LogicalOr()) && + assert(match(skipTrivialSelect(BI.getCondition()), m_LogicalOr()) && "Must have an `or` of `i1`s or `select i1 X, true, Y`s for the " "condition!"); else - assert(match(BI.getCondition(), m_LogicalAnd()) && + assert(match(skipTrivialSelect(BI.getCondition()), m_LogicalAnd()) && "Must have an `and` of `i1`s or `select i1 X, Y, false`s for the" " condition!"); buildPartialUnswitchConditionalBranch( @@ -2071,14 +2083,14 @@ bool Direction = true; int ClonedSucc = 0; if (!FullUnswitch) { - Value *Cond = BI->getCondition(); + Value *Cond = skipTrivialSelect(BI->getCondition()); (void)Cond; assert(((match(Cond, m_LogicalAnd()) ^ match(Cond, m_LogicalOr())) || PartiallyInvariant) && "Only `or`, `and`, an `select`, partially invariant instructions " "can combine invariants being unswitched."); - if (!match(BI->getCondition(), m_LogicalOr())) { - if (match(BI->getCondition(), m_LogicalAnd()) || + if (!match(Cond, m_LogicalOr())) { + if (match(Cond, m_LogicalAnd()) || (PartiallyInvariant && !PartialIVInfo.KnownValue->isOneValue())) { Direction = false; ClonedSucc = 1; @@ -2756,22 +2768,16 @@ BI->getSuccessor(0) == BI->getSuccessor(1)) continue; - // If BI's condition is 'select _, true, false', simplify it to confuse - // matchers - Value *Cond = BI->getCondition(), *CondNext; - while (match(Cond, m_Select(m_Value(CondNext), m_One(), m_Zero()))) - Cond = CondNext; - BI->setCondition(Cond); - + Value *Cond = skipTrivialSelect(BI->getCondition()); if (isa<Constant>(Cond)) continue; - if (L.isLoopInvariant(BI->getCondition())) { - UnswitchCandidates.push_back({BI, {BI->getCondition()}}); + if (L.isLoopInvariant(Cond)) { + UnswitchCandidates.push_back({BI, {Cond}}); continue; } - Instruction &CondI = *cast<Instruction>(BI->getCondition()); + Instruction &CondI = *cast<Instruction>(Cond); if (match(&CondI, m_CombineOr(m_LogicalAnd(), m_LogicalOr()))) { TinyPtrVector<Value *> Invariants = collectHomogenousInstGraphLoopInvariants(L, CondI, LI); @@ 
-2913,10 +2919,11 @@ // its cost. if (!FullUnswitch) { auto &BI = cast<BranchInst>(TI); - if (match(BI.getCondition(), m_LogicalAnd())) { + Value *Cond = skipTrivialSelect(BI.getCondition()); + if (match(Cond, m_LogicalAnd())) { if (SuccBB == BI.getSuccessor(1)) continue; - } else if (match(BI.getCondition(), m_LogicalOr())) { + } else if (match(Cond, m_LogicalOr())) { if (SuccBB == BI.getSuccessor(0)) continue; } else if ((PartialIVInfo.KnownValue->isOneValue() && diff --git a/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-freeze.ll b/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-freeze.ll --- a/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-freeze.ll +++ b/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-freeze.ll @@ -2346,7 +2346,7 @@ ; CHECK: loop: ; CHECK-NEXT: [[TMP2:%.*]] = call i32 @a() ; CHECK-NEXT: [[SEL:%.*]] = select i1 true, i1 true, i1 false -; CHECK-NEXT: br i1 true, label [[LOOP]], label [[EXIT_SPLIT:%.*]] +; CHECK-NEXT: br i1 [[SEL]], label [[LOOP]], label [[EXIT_SPLIT:%.*]] ; CHECK: exit.split: ; CHECK-NEXT: br label [[EXIT]] ; CHECK: exit: diff --git a/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-invariant-select-bug.ll b/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-invariant-select-bug.ll --- a/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-invariant-select-bug.ll +++ b/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-invariant-select-bug.ll @@ -16,7 +16,7 @@ ; CHECK-NEXT: br label [[HEADER:%.*]] ; CHECK: header: ; CHECK-NEXT: [[VAL:%.*]] = select i1 true, i1 true, i1 false -; CHECK-NEXT: br i1 true, label [[EXIT:%.*]], label [[HEADER]] +; CHECK-NEXT: br i1 [[VAL]], label [[EXIT:%.*]], label [[HEADER]] ; CHECK: exit: ; CHECK-NEXT: ret void ; diff --git a/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch.ll b/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch.ll --- a/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch.ll +++ 
b/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch.ll @@ -4277,7 +4277,7 @@ ; CHECK-NEXT: %[[V2:.*]] = load i1, i1* %ptr2 ; CHECK-NEXT: %[[AND1:.*]] = select i1 %[[V1]], i1 true, i1 false ; CHECK-NEXT: %[[AND2:.*]] = select i1 %[[AND1]], i1 true, i1 false -; CHECK-NEXT: br i1 %[[V1]], label %loop_a, label %loop_b +; CHECK-NEXT: br i1 %[[AND2]], label %loop_a, label %loop_b loop_a: call i32 @a() @@ -4357,7 +4357,7 @@ ; CHECK-NEXT: %[[V2:.*]] = load i1, i1* %ptr2 ; CHECK-NEXT: %[[AND1:.*]] = select i1 %[[V1]], i1 true, i1 false ; CHECK-NEXT: %[[AND2:.*]] = select i1 %[[AND1]], i1 true, i1 false -; CHECK-NEXT: br i1 %[[V1]], label %loop_b, label %loop_a +; CHECK-NEXT: br i1 %[[AND2]], label %loop_b, label %loop_a loop_a: call i32 @a() diff --git a/llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch-freeze-individual-conditions.ll b/llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch-freeze-individual-conditions.ll --- a/llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch-freeze-individual-conditions.ll +++ b/llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch-freeze-individual-conditions.ll @@ -11,14 +11,9 @@ ; CHECK-NEXT: [[TMP0:%.*]] = and i1 [[COND4_FR]], [[COND1_FR]] ; CHECK-NEXT: br i1 [[TMP0]], label [[ENTRY_SPLIT:%.*]], label [[EXIT_SPLIT:%.*]] ; CHECK: entry.split: -; CHECK-NEXT: [[COND2_FR:%.*]] = freeze i1 [[COND2:%.*]] -; CHECK-NEXT: [[COND3_FR:%.*]] = freeze i1 [[COND3:%.*]] -; CHECK-NEXT: [[TMP1:%.*]] = and i1 [[COND2_FR]], [[COND3_FR]] -; CHECK-NEXT: br i1 [[TMP1]], label [[ENTRY_SPLIT_SPLIT:%.*]], label [[EXIT_SPLIT1:%.*]] -; CHECK: entry.split.split: ; CHECK-NEXT: br label [[LOOP_HEADER:%.*]] ; CHECK: loop.header: -; CHECK-NEXT: [[OR:%.*]] = or i1 true, true +; CHECK-NEXT: [[OR:%.*]] = or i1 [[COND2:%.*]], [[COND3:%.*]] ; CHECK-NEXT: [[AND1:%.*]] = and i1 [[OR]], true ; CHECK-NEXT: [[AND2:%.*]] = select i1 [[AND1]], i1 true, i1 false ; CHECK-NEXT: br i1 [[AND2]], label [[LOOP_LATCH:%.*]], label [[EXIT:%.*]] @@ -26,8 +21,6 
@@ ; CHECK-NEXT: call void @some_func() ; CHECK-NEXT: br label [[LOOP_HEADER]] ; CHECK: exit: -; CHECK-NEXT: br label [[EXIT_SPLIT1]] -; CHECK: exit.split1: ; CHECK-NEXT: br label [[EXIT_SPLIT]] ; CHECK: exit.split: ; CHECK-NEXT: ret i32 0 @@ -56,14 +49,9 @@ ; CHECK-NEXT: [[TMP0:%.*]] = and i1 [[COND4_FR]], [[COND1:%.*]] ; CHECK-NEXT: br i1 [[TMP0]], label [[ENTRY_SPLIT:%.*]], label [[EXIT_SPLIT:%.*]] ; CHECK: entry.split: -; CHECK-NEXT: [[COND2_FR:%.*]] = freeze i1 [[COND2:%.*]] -; CHECK-NEXT: [[COND3_FR:%.*]] = freeze i1 [[COND3:%.*]] -; CHECK-NEXT: [[TMP1:%.*]] = and i1 [[COND2_FR]], [[COND3_FR]] -; CHECK-NEXT: br i1 [[TMP1]], label [[ENTRY_SPLIT_SPLIT:%.*]], label [[EXIT_SPLIT1:%.*]] -; CHECK: entry.split.split: ; CHECK-NEXT: br label [[LOOP_HEADER:%.*]] ; CHECK: loop.header: -; CHECK-NEXT: [[OR:%.*]] = or i1 true, true +; CHECK-NEXT: [[OR:%.*]] = or i1 [[COND2:%.*]], [[COND3:%.*]] ; CHECK-NEXT: [[AND1:%.*]] = and i1 [[OR]], true ; CHECK-NEXT: [[AND2:%.*]] = select i1 [[AND1]], i1 true, i1 false ; CHECK-NEXT: br i1 [[AND2]], label [[LOOP_LATCH:%.*]], label [[EXIT:%.*]] @@ -71,8 +59,6 @@ ; CHECK-NEXT: call void @some_func() ; CHECK-NEXT: br label [[LOOP_HEADER]] ; CHECK: exit: -; CHECK-NEXT: br label [[EXIT_SPLIT1]] -; CHECK: exit.split1: ; CHECK-NEXT: br label [[EXIT_SPLIT]] ; CHECK: exit.split: ; CHECK-NEXT: ret i32 0 @@ -101,14 +87,9 @@ ; CHECK-NEXT: [[TMP0:%.*]] = and i1 [[COND4:%.*]], [[COND1_FR]] ; CHECK-NEXT: br i1 [[TMP0]], label [[ENTRY_SPLIT:%.*]], label [[EXIT_SPLIT:%.*]] ; CHECK: entry.split: -; CHECK-NEXT: [[COND2_FR:%.*]] = freeze i1 [[COND2:%.*]] -; CHECK-NEXT: [[COND3_FR:%.*]] = freeze i1 [[COND3:%.*]] -; CHECK-NEXT: [[TMP1:%.*]] = and i1 [[COND2_FR]], [[COND3_FR]] -; CHECK-NEXT: br i1 [[TMP1]], label [[ENTRY_SPLIT_SPLIT:%.*]], label [[EXIT_SPLIT1:%.*]] -; CHECK: entry.split.split: ; CHECK-NEXT: br label [[LOOP_HEADER:%.*]] ; CHECK: loop.header: -; CHECK-NEXT: [[OR:%.*]] = or i1 true, true +; CHECK-NEXT: [[OR:%.*]] = or i1 [[COND2:%.*]], 
[[COND3:%.*]] ; CHECK-NEXT: [[AND1:%.*]] = and i1 [[OR]], true ; CHECK-NEXT: [[AND2:%.*]] = select i1 [[AND1]], i1 true, i1 false ; CHECK-NEXT: br i1 [[AND2]], label [[LOOP_LATCH:%.*]], label [[EXIT:%.*]] @@ -116,8 +97,6 @@ ; CHECK-NEXT: call void @some_func() ; CHECK-NEXT: br label [[LOOP_HEADER]] ; CHECK: exit: -; CHECK-NEXT: br label [[EXIT_SPLIT1]] -; CHECK: exit.split1: ; CHECK-NEXT: br label [[EXIT_SPLIT]] ; CHECK: exit.split: ; CHECK-NEXT: ret i32 0 diff --git a/llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch-logical-and-or.ll b/llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch-logical-and-or.ll --- a/llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch-logical-and-or.ll +++ b/llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch-logical-and-or.ll @@ -8,20 +8,15 @@ define void @test_select_logical_and_or_with_and_1(i1 noundef %cond1, i1 noundef %cond2) { ; CHECK-LABEL: @test_select_logical_and_or_with_and_1( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[TMP0:%.*]] = or i1 [[COND2:%.*]], [[COND1:%.*]] -; CHECK-NEXT: br i1 [[TMP0]], label [[EXIT_SPLIT:%.*]], label [[ENTRY_SPLIT:%.*]] -; CHECK: entry.split: ; CHECK-NEXT: br label [[LOOP_HEADER:%.*]] ; CHECK: loop.header: -; CHECK-NEXT: [[COND_AND1:%.*]] = and i1 false, false +; CHECK-NEXT: [[COND_AND1:%.*]] = and i1 [[COND2:%.*]], [[COND1:%.*]] ; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND_AND1]], i1 true, i1 false ; CHECK-NEXT: br i1 [[SEL]], label [[EXIT:%.*]], label [[LOOP_LATCH:%.*]] ; CHECK: loop.latch: ; CHECK-NEXT: call void @some_func() ; CHECK-NEXT: br label [[LOOP_HEADER]] ; CHECK: exit: -; CHECK-NEXT: br label [[EXIT_SPLIT]] -; CHECK: exit.split: ; CHECK-NEXT: ret void ; entry: @@ -114,20 +109,15 @@ define void @test_select_logical_and_or_with_or_2(i1 noundef %cond1, i1 noundef %cond2) { ; CHECK-LABEL: @test_select_logical_and_or_with_or_2( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[TMP0:%.*]] = and i1 [[COND2:%.*]], [[COND1:%.*]] -; CHECK-NEXT: br i1 [[TMP0]], label [[ENTRY_SPLIT:%.*]], 
label [[EXIT_SPLIT:%.*]] -; CHECK: entry.split: ; CHECK-NEXT: br label [[LOOP_HEADER:%.*]] ; CHECK: loop.header: -; CHECK-NEXT: [[COND_AND1:%.*]] = or i1 true, true +; CHECK-NEXT: [[COND_AND1:%.*]] = or i1 [[COND2:%.*]], [[COND1:%.*]] ; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND_AND1]], i1 true, i1 false ; CHECK-NEXT: br i1 [[SEL]], label [[LOOP_LATCH:%.*]], label [[EXIT:%.*]] ; CHECK: loop.latch: ; CHECK-NEXT: call void @some_func() ; CHECK-NEXT: br label [[LOOP_HEADER]] ; CHECK: exit: -; CHECK-NEXT: br label [[EXIT_SPLIT]] -; CHECK: exit.split: ; CHECK-NEXT: ret void ; entry: @@ -147,8 +137,6 @@ } ; Check that loop unswitch looks through a combination of or and select instructions. -; Note that cond6 can be unswitched because `select i1 %cond_or5, i1 true, i1 false` is -; both logical-or and logical-and. define i32 @test_partial_condition_unswitch_or_select(i32* %var, i1 %cond1, i1 %cond2, i1 %cond3, i1 %cond4, i1 %cond5, i1 %cond6) { ; CHECK-LABEL: @test_partial_condition_unswitch_or_select( ; CHECK-NEXT: entry: @@ -161,9 +149,6 @@ ; CHECK-NEXT: [[TMP2:%.*]] = or i1 [[TMP1]], [[COND1_FR]] ; CHECK-NEXT: br i1 [[TMP2]], label [[LOOP_EXIT_SPLIT:%.*]], label [[ENTRY_SPLIT:%.*]] ; CHECK: entry.split: -; CHECK-NEXT: [[COND6_FR:%.*]] = freeze i1 [[COND6:%.*]] -; CHECK-NEXT: br i1 [[COND6_FR]], label [[LOOP_EXIT_SPLIT1:%.*]], label [[ENTRY_SPLIT_SPLIT:%.*]] -; CHECK: entry.split.split: ; CHECK-NEXT: br label [[LOOP_BEGIN:%.*]] ; CHECK: loop_begin: ; CHECK-NEXT: [[VAR_VAL:%.*]] = load i32, i32* [[VAR:%.*]], align 4 @@ -172,7 +157,7 @@ ; CHECK-NEXT: [[COND_OR2:%.*]] = or i1 false, false ; CHECK-NEXT: [[COND_OR3:%.*]] = or i1 [[COND_OR1]], [[COND_OR2]] ; CHECK-NEXT: [[COND_XOR1:%.*]] = xor i1 [[COND5:%.*]], [[VAR_COND]] -; CHECK-NEXT: [[COND_AND1:%.*]] = and i1 false, [[VAR_COND]] +; CHECK-NEXT: [[COND_AND1:%.*]] = and i1 [[COND6:%.*]], [[VAR_COND]] ; CHECK-NEXT: [[COND_OR4:%.*]] = or i1 [[COND_XOR1]], [[COND_AND1]] ; CHECK-NEXT: [[COND_OR5:%.*]] = select i1 [[COND_OR3]], 
i1 true, i1 [[COND_OR4]] ; CHECK-NEXT: [[COND_OR6:%.*]] = select i1 [[COND_OR5]], i1 true, i1 false @@ -181,8 +166,6 @@ ; CHECK-NEXT: call void @some_func() #[[ATTR0:[0-9]+]] ; CHECK-NEXT: br label [[LOOP_BEGIN]] ; CHECK: loop_exit: -; CHECK-NEXT: br label [[LOOP_EXIT_SPLIT1]] -; CHECK: loop_exit.split1: ; CHECK-NEXT: br label [[LOOP_EXIT_SPLIT]] ; CHECK: loop_exit.split: ; CHECK-NEXT: ret i32 0 @@ -211,6 +194,8 @@ ret i32 0 } +; Same as test_partial_condition_unswitch_or_select, but with arguments marked +; as noundef. define i32 @test_partial_condition_unswitch_or_select_noundef(i32* noundef %var, i1 noundef %cond1, i1 noundef %cond2, i1 noundef %cond3, i1 noundef %cond4, i1 noundef %cond5, i1 noundef %cond6) { ; CHECK-LABEL: @test_partial_condition_unswitch_or_select_noundef( ; CHECK-NEXT: entry: @@ -219,8 +204,6 @@ ; CHECK-NEXT: [[TMP2:%.*]] = or i1 [[TMP1]], [[COND1:%.*]] ; CHECK-NEXT: br i1 [[TMP2]], label [[LOOP_EXIT_SPLIT:%.*]], label [[ENTRY_SPLIT:%.*]] ; CHECK: entry.split: -; CHECK-NEXT: br i1 [[COND6:%.*]], label [[LOOP_EXIT_SPLIT1:%.*]], label [[ENTRY_SPLIT_SPLIT:%.*]] -; CHECK: entry.split.split: ; CHECK-NEXT: br label [[LOOP_BEGIN:%.*]] ; CHECK: loop_begin: ; CHECK-NEXT: [[VAR_VAL:%.*]] = load i32, i32* [[VAR:%.*]], align 4 @@ -229,7 +212,7 @@ ; CHECK-NEXT: [[COND_OR2:%.*]] = or i1 false, false ; CHECK-NEXT: [[COND_OR3:%.*]] = or i1 [[COND_OR1]], [[COND_OR2]] ; CHECK-NEXT: [[COND_XOR1:%.*]] = xor i1 [[COND5:%.*]], [[VAR_COND]] -; CHECK-NEXT: [[COND_AND1:%.*]] = and i1 false, [[VAR_COND]] +; CHECK-NEXT: [[COND_AND1:%.*]] = and i1 [[COND6:%.*]], [[VAR_COND]] ; CHECK-NEXT: [[COND_OR4:%.*]] = or i1 [[COND_XOR1]], [[COND_AND1]] ; CHECK-NEXT: [[COND_OR5:%.*]] = select i1 [[COND_OR3]], i1 true, i1 [[COND_OR4]] ; CHECK-NEXT: [[COND_OR6:%.*]] = select i1 [[COND_OR5]], i1 true, i1 false @@ -238,8 +221,6 @@ ; CHECK-NEXT: call void @some_func() #[[ATTR0]] ; CHECK-NEXT: br label [[LOOP_BEGIN]] ; CHECK: loop_exit: -; CHECK-NEXT: br label [[LOOP_EXIT_SPLIT1]] -; 
CHECK: loop_exit.split1: ; CHECK-NEXT: br label [[LOOP_EXIT_SPLIT]] ; CHECK: loop_exit.split: ; CHECK-NEXT: ret i32 0