Index: lib/Transforms/Utils/SimplifyCFG.cpp
===================================================================
--- lib/Transforms/Utils/SimplifyCFG.cpp
+++ lib/Transforms/Utils/SimplifyCFG.cpp
@@ -135,6 +135,8 @@
     NumLookupTablesHoles,
     "Number of switch instructions turned into lookup tables (holes checked)");
 STATISTIC(NumTableCmpReuses, "Number of reused switch table lookup compares");
+STATISTIC(NumSwitchWidened,
+          "Number of switch instructions widened to cover all possible values");
 STATISTIC(NumSinkCommons,
           "Number of common instructions sunk down to the end block");
 STATISTIC(NumSpeculations, "Number of speculative executed instructions");
@@ -4837,6 +4839,110 @@
   SI->eraseFromParent();
 }
 
+/// Try to make the default destination of \p SI unreachable by giving every
+/// value the condition can take an explicit case.
+///
+/// Known-bits analysis bounds the condition to the range [0, 2^k - 1]. When
+/// the existing cases already cover most of that range, each uncovered value
+/// gets a new case that branches to the current default block, and the default
+/// destination is redirected to a fresh block holding only an 'unreachable'.
+///
+/// \param SI  The switch to widen; modified in place on success.
+/// \param DL  Data layout, used to size the condition and compute known bits.
+/// \param TTI Target information. Kept so callers need not change; the
+///            current heuristic does not consult it.
+/// \returns true if the switch was changed.
+static bool widenSwitchRangeToCoverAllValues(SwitchInst *SI,
+                                             const DataLayout &DL,
+                                             const TargetTransformInfo &TTI) {
+  Value *Cond = SI->getCondition();
+  BasicBlock *DefaultBlock = SI->getDefaultDest();
+  const uint64_t NumCases = SI->getNumCases();
+
+  // Nothing to do without cases or when the default is already unreachable.
+  if (NumCases == 0 ||
+      isa<UnreachableInst>(DefaultBlock->getFirstNonPHIOrDbg()))
+    return false;
+
+  unsigned ValueWidth = DL.getTypeSizeInBits(Cond->getType());
+  KnownBits Known(ValueWidth);
+  computeKnownBits(Cond, Known, DL);
+
+  // Only handle non-negative conditions whose known bits have the shape
+  // 0...0?...?: the guaranteed leading zeros plus the remaining low bits must
+  // account for the whole width, so every possible value is in [0, 2^k - 1].
+  if (Known.isNegative() ||
+      Known.countMinLeadingZeros() + Known.countMaxTrailingOnes() !=
+          ValueWidth)
+    return false;
+
+  // With at most 2^32 - 1 cases, the 70% density test below can only pass when
+  // the largest possible value is below 2^33. Bailing out on anything wider
+  // keeps the arithmetic in plain 64-bit integers, where NumCases * 100 cannot
+  // overflow no matter how narrow the condition type is.
+  const APInt MaxSwitchValue = ~Known.Zero;
+  if (MaxSwitchValue.getActiveBits() > 33)
+    return false;
+  const uint64_t MaxValue = MaxSwitchValue.getZExtValue();
+
+  // Require that the cases number at least 70% of the largest possible value,
+  // so that sparse switches are not bloated with many new cases.
+  if (MaxValue <= NumCases || NumCases * 100 / MaxValue < 70)
+    return false;
+
+  // Gather the existing case values inside [0, MaxValue], in increasing order.
+  // Cases above MaxValue can never be taken and are left alone.
+  SmallVector<uint64_t, 16> Covered;
+  for (auto Case : SI->cases()) {
+    const APInt &CaseValue = Case.getCaseValue()->getValue();
+    if (CaseValue.ule(MaxSwitchValue))
+      Covered.push_back(CaseValue.getZExtValue());
+  }
+  array_pod_sort(Covered.begin(), Covered.end());
+
+  // Every value in [0, MaxValue] without a case needs one. This includes the
+  // values below the smallest case, the holes between cases and the values
+  // above the largest case.
+  SmallVector<uint64_t, 16> Missing;
+  auto NextCovered = Covered.begin();
+  for (uint64_t V = 0; V <= MaxValue; ++V) {
+    if (NextCovered != Covered.end() && *NextCovered == V)
+      ++NextCovered;
+    else
+      Missing.push_back(V);
+  }
+
+  // MaxValue > NumCases, so at least two values in the range are uncovered.
+  assert(Missing.size() >= 2 && "density check should leave values to add");
+
+  BasicBlock *SwitchBB = SI->getParent();
+  auto *CondTy = cast<IntegerType>(Cond->getType());
+  for (uint64_t V : Missing)
+    SI->addCase(ConstantInt::get(CondTy, V), DefaultBlock);
+
+  // Each new case is one more edge from SwitchBB into DefaultBlock, and every
+  // such edge needs its own PHI entry. The entry of the old default edge is
+  // taken over by one of the new cases, hence one entry fewer than cases.
+  for (PHINode &PHI : DefaultBlock->phis()) {
+    Value *Incoming = PHI.getIncomingValueForBlock(SwitchBB);
+    for (size_t I = 1, E = Missing.size(); I != E; ++I)
+      PHI.addIncoming(Incoming, SwitchBB);
+  }
+
+  // All possible values now have an explicit case, so the default destination
+  // can never be taken; point it at a block that is just 'unreachable'.
+  // NOTE(review): !prof branch weights on SI are not extended for the new
+  // cases here - confirm whether they need updating.
+  LLVMContext &Ctx = SI->getContext();
+  BasicBlock *NewDefaultBlock = BasicBlock::Create(
+      Ctx, "switch.unreachable.default", DefaultBlock->getParent());
+  new UnreachableInst(Ctx, NewDefaultBlock);
+  SI->setDefaultDest(NewDefaultBlock);
+
+  ++NumSwitchWidened;
+  return true;
+}
+
 /// If the switch is only used to initialize one or more
 /// phi nodes in a common successor block with only two different
 /// constant values, replace the switch with select.
@@ -5609,6 +5715,11 @@
     return simplifyCFG(BB, TTI, Options) || true;
   }
 
+  // Try to convert small number of default values to individual case
+  // statements pointing to the default block.
+  if (widenSwitchRangeToCoverAllValues(SI, DL, TTI))
+    return simplifyCFG(BB, TTI, Options) || true;
+
   // Try to transform the switch into an icmp and a branch.
   if (TurnSwitchRangeIntoICmp(SI, Builder))
     return simplifyCFG(BB, TTI, Options) || true;