diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp --- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -120,12 +120,14 @@ /// homogeneous instruction graph from a given root. /// /// This essentially walks from a root recursively through loop variant operands -/// which have the exact same opcode and finds all inputs which are loop -/// invariant. For some operations these can be re-associated and unswitched out -/// of the loop entirely. -static TinyPtrVector -collectHomogenousInstGraphLoopInvariants(Loop &L, Instruction &Root, - LoopInfo &LI) { +/// which perform the same logical operation (AND or OR) and finds all +/// inputs which are loop invariant. For some operations these can be +/// re-associated and unswitched out of the loop entirely. +/// +/// If \p MaybeNeedsOr is provided, limit the operations to match to either +/// logical ORs (if it is true) or logical ANDs (if it is false). +static TinyPtrVector collectHomogenousInstGraphLoopInvariants( + Loop &L, Instruction &Root, Optional MaybeNeedsOr, LoopInfo &LI) { assert(!L.isLoopInvariant(&Root) && "Only need to walk the graph if root itself is not invariant."); TinyPtrVector Invariants; @@ -133,6 +135,11 @@ bool IsRootAnd = match(&Root, m_LogicalAnd()); bool IsRootOr = match(&Root, m_LogicalOr()); + if (MaybeNeedsOr) { + IsRootAnd &= !*MaybeNeedsOr; + IsRootOr &= *MaybeNeedsOr; + } + // Build a worklist and recurse through operators collecting invariants. SmallVector Worklist; SmallPtrSet Visited; @@ -437,25 +444,6 @@ assert(BI.isConditional() && "Can only unswitch a conditional branch!"); LLVM_DEBUG(dbgs() << " Trying to unswitch branch: " << BI << "\n"); - // The loop invariant values that we want to unswitch. - TinyPtrVector Invariants; - - // When true, we're fully unswitching the branch rather than just unswitching - // some input conditions to the branch. 
- bool FullUnswitch = false; - - if (L.isLoopInvariant(BI.getCondition())) { - Invariants.push_back(BI.getCondition()); - FullUnswitch = true; - } else { - if (auto *CondInst = dyn_cast(BI.getCondition())) - Invariants = collectHomogenousInstGraphLoopInvariants(L, *CondInst, LI); - if (Invariants.empty()) { - LLVM_DEBUG(dbgs() << " Couldn't find invariant inputs!\n"); - return false; - } - } - // Check that one of the branch's successors exits, and which one. bool ExitDirection = true; int LoopExitSuccIdx = 0; @@ -476,6 +464,26 @@ return false; } + // The loop invariant values that we want to unswitch. + TinyPtrVector Invariants; + + // When true, we're fully unswitching the branch rather than just unswitching + // some input conditions to the branch. + bool FullUnswitch = false; + + if (L.isLoopInvariant(BI.getCondition())) { + Invariants.push_back(BI.getCondition()); + FullUnswitch = true; + } else { + if (auto *CondInst = dyn_cast(BI.getCondition())) + Invariants = collectHomogenousInstGraphLoopInvariants( + L, *CondInst, {ExitDirection}, LI); + if (Invariants.empty()) { + LLVM_DEBUG(dbgs() << " Couldn't find invariant inputs!\n"); + return false; + } + } + // When unswitching only part of the branch's condition, we need the exit // block to be reached directly from the partially unswitched input. 
This can // be done when the exit block is along the true edge and the branch condition @@ -2769,7 +2777,7 @@ Instruction &CondI = *cast(BI->getCondition()); if (match(&CondI, m_CombineOr(m_LogicalAnd(), m_LogicalOr()))) { TinyPtrVector Invariants = - collectHomogenousInstGraphLoopInvariants(L, CondI, LI); + collectHomogenousInstGraphLoopInvariants(L, CondI, None, LI); if (Invariants.empty()) continue; diff --git a/llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch-logical-and-or.ll b/llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch-logical-and-or.ll --- a/llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch-logical-and-or.ll +++ b/llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch-logical-and-or.ll @@ -8,20 +8,15 @@ define void @test_select_logical_and_or_with_and_1(i1 noundef %cond1, i1 noundef %cond2) { ; CHECK-LABEL: @test_select_logical_and_or_with_and_1( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[TMP0:%.*]] = or i1 [[COND2:%.*]], [[COND1:%.*]] -; CHECK-NEXT: br i1 [[TMP0]], label [[EXIT_SPLIT:%.*]], label [[ENTRY_SPLIT:%.*]] -; CHECK: entry.split: ; CHECK-NEXT: br label [[LOOP_HEADER:%.*]] ; CHECK: loop.header: -; CHECK-NEXT: [[COND_AND1:%.*]] = and i1 false, false +; CHECK-NEXT: [[COND_AND1:%.*]] = and i1 [[COND2:%.*]], [[COND1:%.*]] ; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND_AND1]], i1 true, i1 false ; CHECK-NEXT: br i1 [[SEL]], label [[EXIT:%.*]], label [[LOOP_LATCH:%.*]] ; CHECK: loop.latch: ; CHECK-NEXT: call void @some_func() ; CHECK-NEXT: br label [[LOOP_HEADER]] ; CHECK: exit: -; CHECK-NEXT: br label [[EXIT_SPLIT]] -; CHECK: exit.split: ; CHECK-NEXT: ret void ; entry: @@ -114,20 +109,15 @@ define void @test_select_logical_and_or_with_or_2(i1 noundef %cond1, i1 noundef %cond2) { ; CHECK-LABEL: @test_select_logical_and_or_with_or_2( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[TMP0:%.*]] = and i1 [[COND2:%.*]], [[COND1:%.*]] -; CHECK-NEXT: br i1 [[TMP0]], label [[ENTRY_SPLIT:%.*]], label [[EXIT_SPLIT:%.*]] -; CHECK: entry.split: ; 
CHECK-NEXT: br label [[LOOP_HEADER:%.*]] ; CHECK: loop.header: -; CHECK-NEXT: [[COND_AND1:%.*]] = or i1 true, true +; CHECK-NEXT: [[COND_AND1:%.*]] = or i1 [[COND2:%.*]], [[COND1:%.*]] ; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND_AND1]], i1 true, i1 false ; CHECK-NEXT: br i1 [[SEL]], label [[LOOP_LATCH:%.*]], label [[EXIT:%.*]] ; CHECK: loop.latch: ; CHECK-NEXT: call void @some_func() ; CHECK-NEXT: br label [[LOOP_HEADER]] ; CHECK: exit: -; CHECK-NEXT: br label [[EXIT_SPLIT]] -; CHECK: exit.split: ; CHECK-NEXT: ret void ; entry: @@ -147,8 +137,6 @@ } ; Check that loop unswitch looks through a combination of or and select instructions. -; Note that cond6 can be unswitched because `select i1 %cond_or5, i1 true, i1 false` is -; both logical-or and logical-and. define i32 @test_partial_condition_unswitch_or_select(i32* %var, i1 %cond1, i1 %cond2, i1 %cond3, i1 %cond4, i1 %cond5, i1 %cond6) { ; CHECK-LABEL: @test_partial_condition_unswitch_or_select( ; CHECK-NEXT: entry: @@ -157,8 +145,6 @@ ; CHECK-NEXT: [[TMP2:%.*]] = or i1 [[TMP1]], [[COND1:%.*]] ; CHECK-NEXT: br i1 [[TMP2]], label [[LOOP_EXIT_SPLIT:%.*]], label [[ENTRY_SPLIT:%.*]] ; CHECK: entry.split: -; CHECK-NEXT: br i1 [[COND6:%.*]], label [[LOOP_EXIT_SPLIT1:%.*]], label [[ENTRY_SPLIT_SPLIT:%.*]] -; CHECK: entry.split.split: ; CHECK-NEXT: br label [[LOOP_BEGIN:%.*]] ; CHECK: loop_begin: ; CHECK-NEXT: [[VAR_VAL:%.*]] = load i32, i32* [[VAR:%.*]], align 4 @@ -167,7 +153,7 @@ ; CHECK-NEXT: [[COND_OR2:%.*]] = or i1 false, false ; CHECK-NEXT: [[COND_OR3:%.*]] = or i1 [[COND_OR1]], [[COND_OR2]] ; CHECK-NEXT: [[COND_XOR1:%.*]] = xor i1 [[COND5:%.*]], [[VAR_COND]] -; CHECK-NEXT: [[COND_AND1:%.*]] = and i1 false, [[VAR_COND]] +; CHECK-NEXT: [[COND_AND1:%.*]] = and i1 [[COND6:%.*]], [[VAR_COND]] ; CHECK-NEXT: [[COND_OR4:%.*]] = or i1 [[COND_XOR1]], [[COND_AND1]] ; CHECK-NEXT: [[COND_OR5:%.*]] = select i1 [[COND_OR3]], i1 true, i1 [[COND_OR4]] ; CHECK-NEXT: [[COND_OR6:%.*]] = select i1 [[COND_OR5]], i1 true, i1 false @@ 
-176,8 +162,6 @@ ; CHECK-NEXT: call void @some_func() #[[ATTR0:[0-9]+]] ; CHECK-NEXT: br label [[LOOP_BEGIN]] ; CHECK: loop_exit: -; CHECK-NEXT: br label [[LOOP_EXIT_SPLIT1]] -; CHECK: loop_exit.split1: ; CHECK-NEXT: br label [[LOOP_EXIT_SPLIT]] ; CHECK: loop_exit.split: ; CHECK-NEXT: ret i32 0 @@ -206,6 +190,8 @@ ret i32 0 } +; Same as test_partial_condition_unswitch_or_select, but with arguments marked +; as noundef. define i32 @test_partial_condition_unswitch_or_select_noundef(i32* noundef %var, i1 noundef %cond1, i1 noundef %cond2, i1 noundef %cond3, i1 noundef %cond4, i1 noundef %cond5, i1 noundef %cond6) { ; CHECK-LABEL: @test_partial_condition_unswitch_or_select_noundef( ; CHECK-NEXT: entry: @@ -214,8 +200,6 @@ ; CHECK-NEXT: [[TMP2:%.*]] = or i1 [[TMP1]], [[COND1:%.*]] ; CHECK-NEXT: br i1 [[TMP2]], label [[LOOP_EXIT_SPLIT:%.*]], label [[ENTRY_SPLIT:%.*]] ; CHECK: entry.split: -; CHECK-NEXT: br i1 [[COND6:%.*]], label [[LOOP_EXIT_SPLIT1:%.*]], label [[ENTRY_SPLIT_SPLIT:%.*]] -; CHECK: entry.split.split: ; CHECK-NEXT: br label [[LOOP_BEGIN:%.*]] ; CHECK: loop_begin: ; CHECK-NEXT: [[VAR_VAL:%.*]] = load i32, i32* [[VAR:%.*]], align 4 @@ -224,7 +208,7 @@ ; CHECK-NEXT: [[COND_OR2:%.*]] = or i1 false, false ; CHECK-NEXT: [[COND_OR3:%.*]] = or i1 [[COND_OR1]], [[COND_OR2]] ; CHECK-NEXT: [[COND_XOR1:%.*]] = xor i1 [[COND5:%.*]], [[VAR_COND]] -; CHECK-NEXT: [[COND_AND1:%.*]] = and i1 false, [[VAR_COND]] +; CHECK-NEXT: [[COND_AND1:%.*]] = and i1 [[COND6:%.*]], [[VAR_COND]] ; CHECK-NEXT: [[COND_OR4:%.*]] = or i1 [[COND_XOR1]], [[COND_AND1]] ; CHECK-NEXT: [[COND_OR5:%.*]] = select i1 [[COND_OR3]], i1 true, i1 [[COND_OR4]] ; CHECK-NEXT: [[COND_OR6:%.*]] = select i1 [[COND_OR5]], i1 true, i1 false @@ -233,8 +217,6 @@ ; CHECK-NEXT: call void @some_func() #[[ATTR0]] ; CHECK-NEXT: br label [[LOOP_BEGIN]] ; CHECK: loop_exit: -; CHECK-NEXT: br label [[LOOP_EXIT_SPLIT1]] -; CHECK: loop_exit.split1: ; CHECK-NEXT: br label [[LOOP_EXIT_SPLIT]] ; CHECK: loop_exit.split: ; 
CHECK-NEXT: ret i32 0