diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -5531,7 +5531,7 @@ -/// destionations CaseDest corresponding to value CaseVal (0 for the default -/// case), of a switch instruction SI. +/// destinations CaseDest corresponding to value CaseVal (nullptr for the +/// default case), of a switch instruction SI. static bool -GetCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, +getCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, BasicBlock **CommonDest, SmallVectorImpl> &Res, const DataLayout &DL, const TargetTransformInfo &TTI) { @@ -5602,7 +5602,7 @@ // Helper function used to add CaseVal to the list of cases that generate // Result. Returns the updated number of cases that generate this result. -static uintptr_t MapCaseToResult(ConstantInt *CaseVal, +static uintptr_t mapCaseToResult(ConstantInt *CaseVal, SwitchCaseResultVectorTy &UniqueResults, Constant *Result) { for (auto &I : UniqueResults) { @@ -5621,7 +5621,7 @@ // instruction. Returns false if multiple PHI nodes have been found or if // there is not a common destination block for the switch. static bool -InitializeUniqueCases(SwitchInst *SI, PHINode *&PHI, BasicBlock *&CommonDest, +initializeUniqueCases(SwitchInst *SI, PHINode *&PHI, BasicBlock *&CommonDest, SwitchCaseResultVectorTy &UniqueResults, Constant *&DefaultResult, const DataLayout &DL, const TargetTransformInfo &TTI, @@ -5631,7 +5631,7 @@ // Resulting value at phi nodes for this case value. SwitchCaseResultsTy Results; - if (!GetCaseResults(SI, CaseVal, I.getCaseSuccessor(), &CommonDest, Results, + if (!getCaseResults(SI, CaseVal, I.getCaseSuccessor(), &CommonDest, Results, DL, TTI)) return false; @@ -5641,7 +5641,7 @@ // Add the case->result mapping to UniqueResults.
const uintptr_t NumCasesForResult = - MapCaseToResult(CaseVal, UniqueResults, Results.begin()->second); + mapCaseToResult(CaseVal, UniqueResults, Results.begin()->second); // Early out if there are too many cases for this result. if (NumCasesForResult > MaxCasesPerResult) @@ -5660,7 +5660,7 @@ // Find the default result value. SmallVector, 1> DefaultResults; BasicBlock *DefaultDest = SI->getDefaultDest(); - GetCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, DefaultResults, + getCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, DefaultResults, DL, TTI); // If the default value is not found abort unless the default destination // is unreachable. @@ -5684,9 +5684,10 @@ // default: // return 4; // } -static Value *ConvertTwoCaseSwitch(const SwitchCaseResultVectorTy &ResultVector, - Constant *DefaultResult, Value *Condition, - IRBuilder<> &Builder) { +// TODO: Handle switches with more than 2 cases that map to the same result. +static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector, + Constant *DefaultResult, Value *Condition, + IRBuilder<> &Builder) { // If we are selecting between only two cases transform into a simple // select or a two-way select if default is possible. if (ResultVector.size() == 2 && ResultVector[0].second.size() == 1 && @@ -5724,10 +5725,10 @@ // Helper function to cleanup a switch instruction that has been converted into // a select, fixing up PHI nodes and basic blocks. 
-static void RemoveSwitchAfterSelectConversion(SwitchInst *SI, PHINode *PHI, - Value *SelectValue, - IRBuilder<> &Builder, - DomTreeUpdater *DTU) { +static void removeSwitchAfterSelectFold(SwitchInst *SI, PHINode *PHI, + Value *SelectValue, + IRBuilder<> &Builder, + DomTreeUpdater *DTU) { std::vector Updates; BasicBlock *SelectBB = SI->getParent(); @@ -5758,33 +5759,32 @@ DTU->applyUpdates(Updates); } -/// If the switch is only used to initialize one or more -/// phi nodes in a common successor block with only two different -/// constant values, replace the switch with select. -static bool switchToSelect(SwitchInst *SI, IRBuilder<> &Builder, - DomTreeUpdater *DTU, const DataLayout &DL, - const TargetTransformInfo &TTI) { +/// If a switch is only used to initialize one or more phi nodes in a common +/// successor block with only two different constant values, try to replace the +/// switch with a select. Returns true if the fold was made. +static bool trySwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder, + DomTreeUpdater *DTU, const DataLayout &DL, + const TargetTransformInfo &TTI) { Value *const Cond = SI->getCondition(); PHINode *PHI = nullptr; BasicBlock *CommonDest = nullptr; Constant *DefaultResult; SwitchCaseResultVectorTy UniqueResults; // Collect all the cases that will deliver the same value from the switch. 
- if (!InitializeUniqueCases(SI, PHI, CommonDest, UniqueResults, DefaultResult, - DL, TTI, /*MaxUniqueResults*/2, - /*MaxCasesPerResult*/2)) + if (!initializeUniqueCases(SI, PHI, CommonDest, UniqueResults, DefaultResult, + DL, TTI, /*MaxUniqueResults*/ 2, + /*MaxCasesPerResult*/ 2)) return false; - assert(PHI != nullptr && "PHI for value select not found"); + assert(PHI != nullptr && "PHI for value select not found"); Builder.SetInsertPoint(SI); Value *SelectValue = - ConvertTwoCaseSwitch(UniqueResults, DefaultResult, Cond, Builder); - if (SelectValue) { - RemoveSwitchAfterSelectConversion(SI, PHI, SelectValue, Builder, DTU); - return true; - } - // The switch couldn't be converted into a select. - return false; + foldSwitchToSelect(UniqueResults, DefaultResult, Cond, Builder); + if (!SelectValue) + return false; + + removeSwitchAfterSelectFold(SI, PHI, SelectValue, Builder, DTU); + return true; } namespace { @@ -6237,7 +6237,7 @@ // Resulting value at phi nodes for this case value. using ResultsTy = SmallVector, 4>; ResultsTy Results; - if (!GetCaseResults(SI, CaseVal, CI->getCaseSuccessor(), &CommonDest, + if (!getCaseResults(SI, CaseVal, CI->getCaseSuccessor(), &CommonDest, Results, DL, TTI)) return false; @@ -6265,7 +6265,7 @@ // or a bitmask that fits in a register. SmallVector, 4> DefaultResultsList; bool HasDefaultResults = - GetCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, + getCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, DefaultResultsList, DL, TTI); bool NeedMask = (TableHasHoles && !HasDefaultResults); @@ -6569,7 +6569,7 @@ if (eliminateDeadSwitchCases(SI, DTU, Options.AC, DL)) return requestResimplify(); - if (switchToSelect(SI, Builder, DTU, DL, TTI)) + if (trySwitchToSelect(SI, Builder, DTU, DL, TTI)) return requestResimplify(); if (Options.ForwardSwitchCondToPhi && ForwardSwitchConditionToPHI(SI))