Index: lib/Transforms/Scalar/LoopUnswitch.cpp =================================================================== --- lib/Transforms/Scalar/LoopUnswitch.cpp +++ lib/Transforms/Scalar/LoopUnswitch.cpp @@ -368,9 +368,20 @@ return new LoopUnswitch(Os); } +/// Lattice describing which operators (AND/OR) an operator chain contains. +enum OperatorChainTy { + OpChainNone, ///< There is no operator. + OpChainOr, ///< There are only ORs. + OpChainAnd, ///< There are only ANDs. + OpChainMixed ///< There are ANDs and ORs. +}; + /// Cond is a condition that occurs in L. If it is invariant in the loop, or has /// an invariant piece, return the invariant. Otherwise, return null. +/// +/// NOTE: if OCS is null, we do not walk up the operator chain. static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, + OperatorChainTy *OCS, DenseMap &Cache) { auto CacheIt = Cache.find(Cond); if (CacheIt != Cache.end()) @@ -394,31 +405,59 @@ return Cond; } - if (BinaryOperator *BO = dyn_cast(Cond)) - if (BO->getOpcode() == Instruction::And || - BO->getOpcode() == Instruction::Or) { - // If either the left or right side is invariant, we can unswitch on this, - // which will cause the branch to go away in one loop and the condition to - // simplify in the other one. - if (Value *LHS = - FindLIVLoopCondition(BO->getOperand(0), L, Changed, Cache)) { - Cache[Cond] = LHS; - return LHS; - } - if (Value *RHS = - FindLIVLoopCondition(BO->getOperand(1), L, Changed, Cache)) { - Cache[Cond] = RHS; - return RHS; + // Walk up the operator chain to find partial invariant conditions. + if (OCS) + if (BinaryOperator *BO = dyn_cast(Cond)) + if (BO->getOpcode() == Instruction::And || + BO->getOpcode() == Instruction::Or) { + // Given the previous operator, compute the current operator chain status. + OperatorChainTy CChain = OpChainNone; + switch (*OCS) { + case OpChainNone: + CChain = BO->getOpcode() == Instruction::And ? OpChainAnd : OpChainOr; + break; + case OpChainOr: + CChain = BO->getOpcode() == Instruction::Or ? 
OpChainOr : OpChainMixed; + break; + case OpChainAnd: + CChain = BO->getOpcode() == Instruction::And ? OpChainAnd : OpChainMixed; + break; + default: + CChain = OpChainMixed; + break; + } + + // Do not keep walking up a chain that mixes ANDs and ORs, as we cannot + // simplify it reliably. NOTE(review): *OCS is never set to OpChainMixed here. + if (CChain != OpChainMixed) { + // Update the current operator chain type before we search up the chain. + *OCS = CChain; + // If either the left or right side is invariant, we can unswitch on this, + // which will cause the branch to go away in one loop and the condition to + // simplify in the other one. + if (Value *LHS = + FindLIVLoopCondition(BO->getOperand(0), L, Changed, OCS, Cache)) { + Cache[Cond] = LHS; + return LHS; + } + // Backtrack: the operand 0 walk may have overwritten *OCS. + *OCS = CChain; + if (Value *RHS = + FindLIVLoopCondition(BO->getOperand(1), L, Changed, OCS, Cache)) { + Cache[Cond] = RHS; + return RHS; + } + } } - } Cache[Cond] = nullptr; return nullptr; } -static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) { +static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, + OperatorChainTy *OCS = nullptr) { DenseMap Cache; - return FindLIVLoopCondition(Cond, L, Changed, Cache); + return FindLIVLoopCondition(Cond, L, Changed, OCS, Cache); } bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { @@ -549,8 +588,18 @@ return false; for (IntrinsicInst *Guard : Guards) { + OperatorChainTy OCS = OpChainNone; Value *LoopCond = - FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed); + FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed, &OCS); + // If we find the loop condition by walking up an operator chain containing + // ANDs and ORs, we cannot simplify the condition into constants in either + // the original or the new loop. + // + // NOTE: we may be able to simplify the condition itself to a loop variant + // condition, but at the cost of creating an entirely new loop. 
+ if (OCS == OpChainMixed) + continue; + if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { // NB! Unswitching (if successful) could have erased some of the @@ -590,8 +639,18 @@ if (BI->isConditional()) { // See if this, or some part of it, is loop invariant. If so, we can // unswitch on it if we desire. + OperatorChainTy OCS = OpChainNone; Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), - currentLoop, Changed); + currentLoop, Changed, &OCS); + // If we find the loop condition by walking up an operator chain + // containing ANDs and ORs, we cannot simplify the condition into + // constants in either the original or the new loop. + // + // NOTE: we may be able to simplify the condition itself to a loop variant + // condition, but at the cost of creating an entirely new loop. + if (OCS == OpChainMixed) + continue; + if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) { ++NumBranches; @@ -599,24 +658,56 @@ } } } else if (SwitchInst *SI = dyn_cast(TI)) { - Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), - currentLoop, Changed); + OperatorChainTy OCS = OpChainNone; + Value *SC = SI->getCondition(); + Value *LoopCond = FindLIVLoopCondition(SC, currentLoop, Changed, &OCS); + // If we find the loop condition by walking up an operator chain + // containing ANDs and ORs, we cannot simplify the condition into + // constants in either the original or the new loop. + // + // NOTE: we may be able to simplify the condition itself to a loop variant + // condition, but at the cost of creating an entirely new loop. + if (OCS == OpChainMixed) + continue; + unsigned NumCases = SI->getNumCases(); if (LoopCond && NumCases) { // Find a value to unswitch on: // FIXME: this should chose the most expensive case! // FIXME: scan for a case with a non-critical edge? Constant *UnswitchVal = nullptr; - - // Do not process same value again and again. 
- // At this point we have some cases already unswitched and - // some not yet unswitched. Let's find the first not yet unswitched one. - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); - i != e; ++i) { - Constant *UnswitchValCandidate = i.getCaseValue(); - if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) { - UnswitchVal = UnswitchValCandidate; - break; + // Find a case value such that at least one case value is unswitched + // out. + if (OCS == OpChainAnd) { + // If the chain only has ANDs and the switch has a case value of 0, + // dropping a 0 into the chain will unswitch out the 0 case value. + auto AllZero = cast(Constant::getNullValue(SC->getType())); + auto CaseVal = SI->findCaseValue(AllZero); + if (CaseVal != SI->case_default() && + !BranchesInfo.isUnswitched(SI, AllZero)) { + UnswitchVal = AllZero; + } + } else if (OCS == OpChainOr) { + // If the chain only has ORs and the switch has a case value of ~0, + // dropping a ~0 into the chain will unswitch out the ~0 case value. + auto AllOne = cast(Constant::getAllOnesValue(SC->getType())); + auto CaseVal = SI->findCaseValue(AllOne); + if (CaseVal != SI->case_default() && + !BranchesInfo.isUnswitched(SI, AllOne)) { + UnswitchVal = AllOne; + } + } else { + assert(OCS == OpChainNone && "Expect to unswitch on trivial operator chain"); + // Do not process same value again and again. + // At this point we have some cases already unswitched and + // some not yet unswitched. Let's find the first not yet unswitched one. 
+ for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); + i != e; ++i) { + Constant *UnswitchValCandidate = i.getCaseValue(); + if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) { + UnswitchVal = UnswitchValCandidate; + break; + } } } @@ -634,8 +725,18 @@ for (BasicBlock::iterator BBI = (*I)->begin(), E = (*I)->end(); BBI != E; ++BBI) if (SelectInst *SI = dyn_cast(BBI)) { + OperatorChainTy OCS = OpChainNone; Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), - currentLoop, Changed); + currentLoop, Changed, &OCS); + // If we find the loop condition by walking up an operator chain containing + // ANDs and ORs, we cannot simplify the condition into constants in either + // the original or the new loop. + // + // NOTE: we may be able to simplify the condition itself to a loop variant + // condition, but at the cost of creating an entirely new loop. + if (OCS == OpChainMixed) + continue; + if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { ++NumSelects; @@ -898,9 +999,11 @@ // Unswitch only if the trivial condition itself is an LIV (not // partial LIV which could occur in and/or) - if (!LoopCond || LoopCond != BI->getCondition()) + if (!LoopCond) return false; + // We did not walk up the operator chain. + assert(LoopCond == BI->getCondition() && "Should unswitch on branch cond"); // Check to see if a successor of the branch is guaranteed to // exit through a unique exit block without having any // side-effects. If so, determine the value of Cond that causes @@ -929,9 +1032,11 @@ // Unswitch only if the trivial condition itself is an LIV (not // partial LIV which could occur in and/or) - if (!LoopCond || LoopCond != SI->getCondition()) + if (!LoopCond) return false; + // We did not walk up the operator chain. 
+ assert(LoopCond == SI->getCondition() && "Should unswitch on switch cond"); // Check to see if a successor of the switch is guaranteed to go to the // latch block or exit through a one exit block without having any // side-effects. If so, determine the value of Cond that causes it to do Index: test/Transforms/LoopUnswitch/basictest.ll =================================================================== --- test/Transforms/LoopUnswitch/basictest.ll +++ test/Transforms/LoopUnswitch/basictest.ll @@ -101,6 +101,165 @@ ; CHECK: } } +; Make sure we unswitch %a == 0 out of the loop. +; +; CHECK: define void @and_i2_as_switch_input(i2 +; CHECK: entry: +; This is an indication that the loop has been unswitched. +; CHECK: icmp eq i2 %a, 0 +; CHECK: br +define void @and_i2_as_switch_input(i2 %a) { +entry: + br label %for.body + +for.body: + %i = phi i2 [ 0, %entry ], [ %inc, %for.inc ] + %and = and i2 %a, %i + switch i2 %and, label %sw.default [ + i2 0, label %sw.bb + i2 1, label %sw.bb1 + ] + +sw.bb: + br label %sw.epilog + +sw.bb1: + br label %sw.epilog + +sw.default: + br label %sw.epilog + +sw.epilog: + br label %for.inc + +for.inc: + %inc = add nsw i2 %i, 1 + %cmp = icmp slt i2 %inc, 3 + br i1 %cmp, label %for.body, label %for.end + +for.end: + ret void +} + +; Make sure we unswitch %a == ~0 (all ones, i.e. -1) out of the loop. +; +; CHECK: define void @or_i2_as_switch_input(i2 +; CHECK: entry: +; This is an indication that the loop has been unswitched. 
+; CHECK: icmp eq i2 %a, -1 +; CHECK: br +define void @or_i2_as_switch_input(i2 %a) { +entry: + br label %for.body + +for.body: + %i = phi i2 [ 0, %entry ], [ %inc, %for.inc ] + %or = or i2 %a, %i + switch i2 %or, label %sw.default [ + i2 2, label %sw.bb + i2 3, label %sw.bb1 + ] + +sw.bb: + br label %sw.epilog + +sw.bb1: + br label %sw.epilog + +sw.default: + br label %sw.epilog + +sw.epilog: + br label %for.inc + +for.inc: + %inc = add nsw i2 %i, 1 + %cmp = icmp slt i2 %inc, 3 + br i1 %cmp, label %for.body, label %for.end + +for.end: + ret void +} + +; Make sure we dont unswitch, as we can not find a input value %a +; that will effectively unswitch 0 or 3 out of the loop. +; +; CHECK: define void @and_or_i2_as_switch_input(i2 +; CHECK: entry: +; This is an indication that the loop has NOT been unswitched. +; CHECK-NOT: icmp +; CHECK: br +define void @and_or_i2_as_switch_input(i2 %a) { +entry: + br label %for.body + +for.body: + %i = phi i2 [ 0, %entry ], [ %inc, %for.inc ] + %and = and i2 %a, %i + %or = or i2 %and, %i + switch i2 %or, label %sw.default [ + i2 0, label %sw.bb + i2 3, label %sw.bb1 + ] + +sw.bb: + br label %sw.epilog + +sw.bb1: + br label %sw.epilog + +sw.default: + br label %sw.epilog + +sw.epilog: + br label %for.inc + +for.inc: + %inc = add nsw i2 %i, 1 + %cmp = icmp slt i2 %inc, 3 + br i1 %cmp, label %for.body, label %for.end + +for.end: + ret void +} + + +; Make sure we dont unswitch, as we can not find a input value %a +; that will effectively unswitch true/false out of the loop. +; +; CHECK: define void @and_or_i1_as_branch_input(i1 +; CHECK: entry: +; This is an indication that the loop has NOT been unswitched. 
+; CHECK-NOT: icmp +; CHECK: br +define void @and_or_i1_as_branch_input(i1 %a) { +entry: + br label %for.body + +for.body: + %i = phi i1 [ 0, %entry ], [ %inc, %for.inc ] + %and = and i1 %a, %i + %or = or i1 %and, %i + br i1 %or, label %sw.bb, label %sw.bb1 + +sw.bb: + br label %sw.epilog + +sw.bb1: + br label %sw.epilog + +sw.epilog: + br label %for.inc + +for.inc: + %inc = add nsw i1 %i, 1 + %cmp = icmp slt i1 %inc, 1 + br i1 %cmp, label %for.body, label %for.end + +for.end: + ret void +} + declare void @incf() noreturn declare void @decf() noreturn