Index: lib/Transforms/Scalar/LoopUnswitch.cpp =================================================================== --- lib/Transforms/Scalar/LoopUnswitch.cpp +++ lib/Transforms/Scalar/LoopUnswitch.cpp @@ -368,9 +368,27 @@ return new LoopUnswitch(Os); } +/// Operator chain lattice. +enum OperatorChain { + OC_OpChainNone, ///< There is no operator. + OC_OpChainOr, ///< There are only ORs. + OC_OpChainAnd, ///< There are only ANDs. + OC_OpChainMixed ///< There are ANDs and ORs. +}; + /// Cond is a condition that occurs in L. If it is invariant in the loop, or has /// an invariant piece, return the invariant. Otherwise, return null. +/// +/// NOTE: FindLIVLoopCondition will not return a partial LIV by walking up a +/// mixed operator chain, as we can not reliably find a value which will simplify +/// the operator chain. If the chain is AND-only or OR-only, we can use 0 or ~0 +/// to simplify the chain. +/// +/// NOTE: In case of a partial LIV and a mixed operator chain, we may be able to +/// simplify the condition itself to a loop variant condition, but at the +/// cost of creating an entirely new loop. static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, + OperatorChain &ParentChain, DenseMap &Cache) { auto CacheIt = Cache.find(Cond); if (CacheIt != Cache.end()) @@ -394,21 +412,53 @@ return Cond; } + // Walk up the operator chain to find partial invariant conditions. if (BinaryOperator *BO = dyn_cast(Cond)) if (BO->getOpcode() == Instruction::And || BO->getOpcode() == Instruction::Or) { - // If either the left or right side is invariant, we can unswitch on this, - // which will cause the branch to go away in one loop and the condition to - // simplify in the other one. - if (Value *LHS = - FindLIVLoopCondition(BO->getOperand(0), L, Changed, Cache)) { - Cache[Cond] = LHS; - return LHS; + // Given the previous operator, compute the current operator chain status. 
+ OperatorChain NewChain; + switch (ParentChain) { + case OC_OpChainNone: + NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd : + OC_OpChainOr; + break; + case OC_OpChainOr: + NewChain = BO->getOpcode() == Instruction::Or ? OC_OpChainOr : + OC_OpChainMixed; + break; + case OC_OpChainAnd: + NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd : + OC_OpChainMixed; + break; + case OC_OpChainMixed: + NewChain = OC_OpChainMixed; + break; } - if (Value *RHS = - FindLIVLoopCondition(BO->getOperand(1), L, Changed, Cache)) { - Cache[Cond] = RHS; - return RHS; + + // If we reach a Mixed state, we do not want to keep walking up as we can not + // reliably find a value that will simplify the chain. With this check, we + // will return null on the first sight of mixed chain and the caller will + // either backtrack to find partial LIV in other operand or return null. + if (NewChain != OC_OpChainMixed) { + // Update the current operator chain type before we search up the chain. + ParentChain = NewChain; + // If either the left or right side is invariant, we can unswitch on this, + // which will cause the branch to go away in one loop and the condition to + // simplify in the other one. + if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed, + ParentChain, Cache)) { + Cache[Cond] = LHS; + return LHS; + } + // We did not manage to find a partial LIV in operand(0). Backtrack and try + // operand(1). + ParentChain = NewChain; + if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed, + ParentChain, Cache)) { + Cache[Cond] = RHS; + return RHS; + } } } @@ -416,9 +466,21 @@ return nullptr; } -static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) { +/// Cond is a condition that occurs in L. If it is invariant in the loop, or has +/// an invariant piece, return the invariant along with the operator chain type. +/// Otherwise, return null. 
+static std::pair FindLIVLoopCondition(Value *Cond, + Loop *L, + bool &Changed) { DenseMap Cache; - return FindLIVLoopCondition(Cond, L, Changed, Cache); + OperatorChain OpChain = OC_OpChainNone; + Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache); + + // In case we do find a LIV, it can not be obtained by walking up a mixed + // operator chain. + assert((!FCond || OpChain != OC_OpChainMixed) && + "Do not expect a partial LIV with mixed operator chain"); + return {FCond, OpChain}; } bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { @@ -550,7 +612,7 @@ for (IntrinsicInst *Guard : Guards) { Value *LoopCond = - FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed); + FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed).first; if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { // NB! Unswitching (if successful) could have erased some of the @@ -591,7 +653,7 @@ // See if this, or some part of it, is loop invariant. If so, we can // unswitch on it if we desire. Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), - currentLoop, Changed); + currentLoop, Changed).first; if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) { ++NumBranches; @@ -599,24 +661,49 @@ } } } else if (SwitchInst *SI = dyn_cast(TI)) { - Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), - currentLoop, Changed); + Value *SC = SI->getCondition(); + Value *LoopCond; + OperatorChain OpChain; + std::tie(LoopCond, OpChain) = + FindLIVLoopCondition(SC, currentLoop, Changed); + unsigned NumCases = SI->getNumCases(); if (LoopCond && NumCases) { // Find a value to unswitch on: // FIXME: this should chose the most expensive case! // FIXME: scan for a case with a non-critical edge? Constant *UnswitchVal = nullptr; - - // Do not process same value again and again. - // At this point we have some cases already unswitched and - // some not yet unswitched. 
Let's find the first not yet unswitched one. - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); - i != e; ++i) { - Constant *UnswitchValCandidate = i.getCaseValue(); - if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) { - UnswitchVal = UnswitchValCandidate; - break; + // Find a case value such that at least one case value is unswitched + // out. + if (OpChain == OC_OpChainAnd) { + // If the chain only has ANDs and the switch has a case value of 0. + // Dropping in a 0 to the chain will unswitch out the 0-casevalue. + auto *AllZero = cast(Constant::getNullValue(SC->getType())); + if (BranchesInfo.isUnswitched(SI, AllZero)) + continue; + // We are unswitching 0 out. + UnswitchVal = AllZero; + } else if (OpChain == OC_OpChainOr) { + // If the chain only has ORs and the switch has a case value of ~0. + // Dropping in a ~0 to the chain will unswitch out the ~0-casevalue. + auto *AllOne = cast(Constant::getAllOnesValue(SC->getType())); + if (BranchesInfo.isUnswitched(SI, AllOne)) + continue; + // We are unswitching ~0 out. + UnswitchVal = AllOne; + } else { + assert(OpChain == OC_OpChainNone && + "Expect to unswitch on trivial chain"); + // Do not process same value again and again. + // At this point we have some cases already unswitched and + // some not yet unswitched. Let's find the first not yet unswitched one. + for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); + i != e; ++i) { + Constant *UnswitchValCandidate = i.getCaseValue(); + if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) { + UnswitchVal = UnswitchValCandidate; + break; + } } } @@ -625,6 +712,11 @@ if (UnswitchIfProfitable(LoopCond, UnswitchVal)) { ++NumSwitches; + // In case of a full LIV, UnswitchVal is the value we unswitched out. + // In case of a partial LIV, we only unswitch when it's an AND-chain + // or OR-chain. In both cases switch input value simplifies to + // UnswitchVal. 
+ BranchesInfo.setUnswitched(SI, UnswitchVal); return true; } } @@ -635,7 +727,7 @@ BBI != E; ++BBI) if (SelectInst *SI = dyn_cast(BBI)) { Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), - currentLoop, Changed); + currentLoop, Changed).first; if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { ++NumSelects; @@ -894,7 +986,7 @@ return false; Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), - currentLoop, Changed); + currentLoop, Changed).first; // Unswitch only if the trivial condition itself is an LIV (not // partial LIV which could occur in and/or) @@ -925,7 +1017,7 @@ } else if (SwitchInst *SI = dyn_cast(CurrentTerm)) { // If this isn't switching on an invariant condition, we can't unswitch it. Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), - currentLoop, Changed); + currentLoop, Changed).first; // Unswitch only if the trivial condition itself is an LIV (not // partial LIV which could occur in and/or) @@ -963,6 +1055,9 @@ UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB, nullptr); + + // We are only unswitching full LIV. + BranchesInfo.setUnswitched(SI, CondVal); ++NumSwitches; return true; } @@ -1227,6 +1322,9 @@ SwitchInst *SI = dyn_cast(UI); if (!SI || !isa(Val)) continue; + // NOTE: if a case value for the switch is unswitched out, we record it + // after the unswitch finishes. We can not record it here as the switch + // is not a direct user of the partial LIV. SwitchInst::CaseIt DeadCase = SI->findCaseValue(cast(Val)); // Default case is live for multiple values. if (DeadCase == SI->case_default()) continue; @@ -1239,8 +1337,6 @@ BasicBlock *SISucc = DeadCase.getCaseSuccessor(); BasicBlock *Latch = L->getLoopLatch(); - BranchesInfo.setUnswitched(SI, Val); - if (!SI->findCaseDest(SISucc)) continue; // Edge is critical. 
// If the DeadCase successor dominates the loop latch, then the // transformation isn't safe since it will delete the sole predecessor edge Index: test/Transforms/LoopUnswitch/basictest.ll =================================================================== --- test/Transforms/LoopUnswitch/basictest.ll +++ test/Transforms/LoopUnswitch/basictest.ll @@ -101,6 +101,217 @@ ; CHECK: } } +; Make sure we unswitch %a == 0 out of the loop. +; +; CHECK: define void @and_i2_as_switch_input(i2 +; CHECK: entry: +; This is an indication that the loop has been unswitched. +; CHECK: icmp eq i2 %a, 0 +; CHECK: br +; There should be no more unswitching after the 1st unswitch. +; CHECK-NOT: icmp eq +; CHECK: ret +define void @and_i2_as_switch_input(i2 %a) { +entry: + br label %for.body + +for.body: + %i = phi i2 [ 0, %entry ], [ %inc, %for.inc ] + %and = and i2 %a, %i + %and1 = and i2 %and, %i + switch i2 %and1, label %sw.default [ + i2 0, label %sw.bb + i2 1, label %sw.bb1 + ] + +sw.bb: + br label %sw.epilog + +sw.bb1: + br label %sw.epilog + +sw.default: + br label %sw.epilog + +sw.epilog: + br label %for.inc + +for.inc: + %inc = add nsw i2 %i, 1 + %cmp = icmp slt i2 %inc, 3 + br i1 %cmp, label %for.body, label %for.end + +for.end: + ret void +} + +; Make sure we unswitch %a == ~0 out of the loop. +; +; CHECK: define void @or_i2_as_switch_input(i2 +; CHECK: entry: +; This is an indication that the loop has been unswitched. +; CHECK: icmp eq i2 %a, -1 +; CHECK: br +; There should be no more unswitching after the 1st unswitch. 
+; CHECK-NOT: icmp eq +; CHECK: ret +define void @or_i2_as_switch_input(i2 %a) { +entry: + br label %for.body + +for.body: + %i = phi i2 [ 0, %entry ], [ %inc, %for.inc ] + %or = or i2 %a, %i + %or1 = or i2 %or, %i + switch i2 %or1, label %sw.default [ + i2 2, label %sw.bb + i2 3, label %sw.bb1 + ] + +sw.bb: + br label %sw.epilog + +sw.bb1: + br label %sw.epilog + +sw.default: + br label %sw.epilog + +sw.epilog: + br label %for.inc + +for.inc: + %inc = add nsw i2 %i, 1 + %cmp = icmp slt i2 %inc, 3 + br i1 %cmp, label %for.body, label %for.end + +for.end: + ret void +} + +; Make sure we unswitch %a == ~0 out of the loop, even though we do not +; have it as a case value. Unswitching it out allows us to simplify +; the or operator chain. +; +; CHECK: define void @or_i2_as_switch_input_unswitch_default(i2 +; CHECK: entry: +; This is an indication that the loop has been unswitched. +; CHECK: icmp eq i2 %a, -1 +; CHECK: br +; There should be no more unswitching after the 1st unswitch. +; CHECK-NOT: icmp eq +; CHECK: ret +define void @or_i2_as_switch_input_unswitch_default(i2 %a) { +entry: + br label %for.body + +for.body: + %i = phi i2 [ 0, %entry ], [ %inc, %for.inc ] + %or = or i2 %a, %i + %or1 = or i2 %or, %i + switch i2 %or1, label %sw.default [ + i2 1, label %sw.bb + i2 2, label %sw.bb1 + ] + +sw.bb: + br label %sw.epilog + +sw.bb1: + br label %sw.epilog + +sw.default: + br label %sw.epilog + +sw.epilog: + br label %for.inc + +for.inc: + %inc = add nsw i2 %i, 1 + %cmp = icmp slt i2 %inc, 3 + br i1 %cmp, label %for.body, label %for.end + +for.end: + ret void +} + +; Make sure we don't unswitch, as we can not find an input value %a +; that will effectively unswitch 0 or 3 out of the loop. +; +; CHECK: define void @and_or_i2_as_switch_input(i2 +; CHECK: entry: +; This is an indication that the loop has NOT been unswitched. 
+; CHECK-NOT: icmp +; CHECK: br +define void @and_or_i2_as_switch_input(i2 %a) { +entry: + br label %for.body + +for.body: + %i = phi i2 [ 0, %entry ], [ %inc, %for.inc ] + %and = and i2 %a, %i + %or = or i2 %and, %i + switch i2 %or, label %sw.default [ + i2 0, label %sw.bb + i2 3, label %sw.bb1 + ] + +sw.bb: + br label %sw.epilog + +sw.bb1: + br label %sw.epilog + +sw.default: + br label %sw.epilog + +sw.epilog: + br label %for.inc + +for.inc: + %inc = add nsw i2 %i, 1 + %cmp = icmp slt i2 %inc, 3 + br i1 %cmp, label %for.body, label %for.end + +for.end: + ret void +} + +; Make sure we don't unswitch, as we can not find an input value %a +; that will effectively unswitch true/false out of the loop. +; +; CHECK: define void @and_or_i1_as_branch_input(i1 +; CHECK: entry: +; This is an indication that the loop has NOT been unswitched. +; CHECK-NOT: icmp +; CHECK: br +define void @and_or_i1_as_branch_input(i1 %a) { +entry: + br label %for.body + +for.body: + %i = phi i1 [ 0, %entry ], [ %inc, %for.inc ] + %and = and i1 %a, %i + %or = or i1 %and, %i + br i1 %or, label %sw.bb, label %sw.bb1 + +sw.bb: + br label %sw.epilog + +sw.bb1: + br label %sw.epilog + +sw.epilog: + br label %for.inc + +for.inc: + %inc = add nsw i1 %i, 1 + %cmp = icmp slt i1 %inc, 1 + br i1 %cmp, label %for.body, label %for.end + +for.end: + ret void +} declare void @incf() noreturn declare void @decf() noreturn