Changeset View
Changeset View
Standalone View
Standalone View
llvm/trunk/lib/Transforms/Utils/LowerSwitch.cpp
Show All 10 Lines | |||||
// switch instruction until it is convenient. | // switch instruction until it is convenient. | ||||
// | // | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
#include "llvm/ADT/DenseMap.h" | #include "llvm/ADT/DenseMap.h" | ||||
#include "llvm/ADT/STLExtras.h" | #include "llvm/ADT/STLExtras.h" | ||||
#include "llvm/ADT/SmallPtrSet.h" | #include "llvm/ADT/SmallPtrSet.h" | ||||
#include "llvm/ADT/SmallVector.h" | #include "llvm/ADT/SmallVector.h" | ||||
#include "llvm/Analysis/AssumptionCache.h" | |||||
#include "llvm/Analysis/LazyValueInfo.h" | |||||
#include "llvm/Analysis/ValueTracking.h" | |||||
#include "llvm/IR/BasicBlock.h" | #include "llvm/IR/BasicBlock.h" | ||||
#include "llvm/IR/CFG.h" | #include "llvm/IR/CFG.h" | ||||
#include "llvm/IR/ConstantRange.h" | |||||
#include "llvm/IR/Constants.h" | #include "llvm/IR/Constants.h" | ||||
#include "llvm/IR/Function.h" | #include "llvm/IR/Function.h" | ||||
#include "llvm/IR/InstrTypes.h" | #include "llvm/IR/InstrTypes.h" | ||||
#include "llvm/IR/Instructions.h" | #include "llvm/IR/Instructions.h" | ||||
#include "llvm/IR/Value.h" | #include "llvm/IR/Value.h" | ||||
#include "llvm/Pass.h" | #include "llvm/Pass.h" | ||||
#include "llvm/Support/Casting.h" | #include "llvm/Support/Casting.h" | ||||
#include "llvm/Support/Compiler.h" | #include "llvm/Support/Compiler.h" | ||||
#include "llvm/Support/Debug.h" | #include "llvm/Support/Debug.h" | ||||
#include "llvm/Support/KnownBits.h" | |||||
#include "llvm/Support/raw_ostream.h" | #include "llvm/Support/raw_ostream.h" | ||||
#include "llvm/Transforms/Utils.h" | #include "llvm/Transforms/Utils.h" | ||||
#include "llvm/Transforms/Utils/BasicBlockUtils.h" | #include "llvm/Transforms/Utils/BasicBlockUtils.h" | ||||
#include <algorithm> | #include <algorithm> | ||||
#include <cassert> | #include <cassert> | ||||
#include <cstdint> | #include <cstdint> | ||||
#include <iterator> | #include <iterator> | ||||
#include <limits> | #include <limits> | ||||
Show All 34 Lines | public: | ||||
static char ID; | static char ID; | ||||
LowerSwitch() : FunctionPass(ID) { | LowerSwitch() : FunctionPass(ID) { | ||||
initializeLowerSwitchPass(*PassRegistry::getPassRegistry()); | initializeLowerSwitchPass(*PassRegistry::getPassRegistry()); | ||||
} | } | ||||
bool runOnFunction(Function &F) override; | bool runOnFunction(Function &F) override; | ||||
void getAnalysisUsage(AnalysisUsage &AU) const override { | |||||
AU.addRequired<LazyValueInfoWrapperPass>(); | |||||
} | |||||
struct CaseRange { | struct CaseRange { | ||||
ConstantInt* Low; | ConstantInt* Low; | ||||
ConstantInt* High; | ConstantInt* High; | ||||
BasicBlock* BB; | BasicBlock* BB; | ||||
CaseRange(ConstantInt *low, ConstantInt *high, BasicBlock *bb) | CaseRange(ConstantInt *low, ConstantInt *high, BasicBlock *bb) | ||||
: Low(low), High(high), BB(bb) {} | : Low(low), High(high), BB(bb) {} | ||||
}; | }; | ||||
using CaseVector = std::vector<CaseRange>; | using CaseVector = std::vector<CaseRange>; | ||||
using CaseItr = std::vector<CaseRange>::iterator; | using CaseItr = std::vector<CaseRange>::iterator; | ||||
private: | private: | ||||
void processSwitchInst(SwitchInst *SI, SmallPtrSetImpl<BasicBlock*> &DeleteList); | void processSwitchInst(SwitchInst *SI, | ||||
SmallPtrSetImpl<BasicBlock *> &DeleteList, | |||||
AssumptionCache *AC, LazyValueInfo *LVI); | |||||
BasicBlock *switchConvert(CaseItr Begin, CaseItr End, | BasicBlock *switchConvert(CaseItr Begin, CaseItr End, | ||||
ConstantInt *LowerBound, ConstantInt *UpperBound, | ConstantInt *LowerBound, ConstantInt *UpperBound, | ||||
Value *Val, BasicBlock *Predecessor, | Value *Val, BasicBlock *Predecessor, | ||||
BasicBlock *OrigBlock, BasicBlock *Default, | BasicBlock *OrigBlock, BasicBlock *Default, | ||||
const std::vector<IntRange> &UnreachableRanges); | const std::vector<IntRange> &UnreachableRanges); | ||||
BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, BasicBlock *OrigBlock, | BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, | ||||
BasicBlock *Default); | ConstantInt *LowerBound, ConstantInt *UpperBound, | ||||
BasicBlock *OrigBlock, BasicBlock *Default); | |||||
unsigned Clusterify(CaseVector &Cases, SwitchInst *SI); | unsigned Clusterify(CaseVector &Cases, SwitchInst *SI); | ||||
}; | }; | ||||
/// The comparison function for sorting the switch case values in the vector. | /// The comparison function for sorting the switch case values in the vector. | ||||
/// WARNING: Case ranges should be disjoint! | /// WARNING: Case ranges should be disjoint! | ||||
struct CaseCmp { | struct CaseCmp { | ||||
bool operator()(const LowerSwitch::CaseRange& C1, | bool operator()(const LowerSwitch::CaseRange& C1, | ||||
const LowerSwitch::CaseRange& C2) { | const LowerSwitch::CaseRange& C2) { | ||||
const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low); | const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low); | ||||
const ConstantInt* CI2 = cast<const ConstantInt>(C2.High); | const ConstantInt* CI2 = cast<const ConstantInt>(C2.High); | ||||
return CI1->getValue().slt(CI2->getValue()); | return CI1->getValue().slt(CI2->getValue()); | ||||
} | } | ||||
}; | }; | ||||
} // end anonymous namespace | } // end anonymous namespace | ||||
char LowerSwitch::ID = 0; | char LowerSwitch::ID = 0; | ||||
// Publicly exposed interface to pass... | // Publicly exposed interface to pass... | ||||
char &llvm::LowerSwitchID = LowerSwitch::ID; | char &llvm::LowerSwitchID = LowerSwitch::ID; | ||||
INITIALIZE_PASS(LowerSwitch, "lowerswitch", | INITIALIZE_PASS_BEGIN(LowerSwitch, "lowerswitch", | ||||
"Lower SwitchInst's to branches", false, false) | |||||
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) | |||||
INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) | |||||
INITIALIZE_PASS_END(LowerSwitch, "lowerswitch", | |||||
"Lower SwitchInst's to branches", false, false) | "Lower SwitchInst's to branches", false, false) | ||||
// createLowerSwitchPass - Interface to this file... | // createLowerSwitchPass - Interface to this file... | ||||
FunctionPass *llvm::createLowerSwitchPass() { | FunctionPass *llvm::createLowerSwitchPass() { | ||||
return new LowerSwitch(); | return new LowerSwitch(); | ||||
} | } | ||||
bool LowerSwitch::runOnFunction(Function &F) { | bool LowerSwitch::runOnFunction(Function &F) { | ||||
LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); | |||||
auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>(); | |||||
AssumptionCache *AC = ACT ? &ACT->getAssumptionCache(F) : nullptr; | |||||
// Prevent LazyValueInfo from using the DominatorTree as LowerSwitch does not | |||||
// preserve it and it becomes stale (when available) pretty much immediately. | |||||
// Currently the DominatorTree is only used by LowerSwitch indirectly via LVI | |||||
// and computeKnownBits to refine isValidAssumeForContext's results. Given | |||||
// that the latter can handle some of the simple cases w/o a DominatorTree, | |||||
// it's easier to refrain from using the tree than to keep it up to date. | |||||
LVI->disableDT(); | |||||
bool Changed = false; | bool Changed = false; | ||||
SmallPtrSet<BasicBlock*, 8> DeleteList; | SmallPtrSet<BasicBlock*, 8> DeleteList; | ||||
for (Function::iterator I = F.begin(), E = F.end(); I != E; ) { | for (Function::iterator I = F.begin(), E = F.end(); I != E; ) { | ||||
BasicBlock *Cur = &*I++; // Advance over block so we don't traverse new blocks | BasicBlock *Cur = &*I++; // Advance over block so we don't traverse new blocks | ||||
// If the block is a dead Default block that will be deleted later, don't | // If the block is a dead Default block that will be deleted later, don't | ||||
// waste time processing it. | // waste time processing it. | ||||
if (DeleteList.count(Cur)) | if (DeleteList.count(Cur)) | ||||
continue; | continue; | ||||
if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { | if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { | ||||
Changed = true; | Changed = true; | ||||
processSwitchInst(SI, DeleteList); | processSwitchInst(SI, DeleteList, AC, LVI); | ||||
} | } | ||||
} | } | ||||
for (BasicBlock* BB: DeleteList) { | for (BasicBlock* BB: DeleteList) { | ||||
LVI->eraseBlock(BB); | |||||
DeleteDeadBlock(BB); | DeleteDeadBlock(BB); | ||||
} | } | ||||
return Changed; | return Changed; | ||||
} | } | ||||
/// Used for debugging purposes. | /// Used for debugging purposes. | ||||
LLVM_ATTRIBUTE_USED | LLVM_ATTRIBUTE_USED | ||||
static raw_ostream &operator<<(raw_ostream &O, | static raw_ostream &operator<<(raw_ostream &O, | ||||
const LowerSwitch::CaseVector &C) { | const LowerSwitch::CaseVector &C) { | ||||
O << "["; | O << "["; | ||||
for (LowerSwitch::CaseVector::const_iterator B = C.begin(), | for (LowerSwitch::CaseVector::const_iterator B = C.begin(), E = C.end(); | ||||
E = C.end(); B != E; ) { | B != E;) { | ||||
O << *B->Low << " -" << *B->High; | O << "[" << B->Low->getValue() << ", " << B->High->getValue() << "]"; | ||||
if (++B != E) O << ", "; | if (++B != E) | ||||
O << ", "; | |||||
} | } | ||||
return O << "]"; | return O << "]"; | ||||
} | } | ||||
/// Update the first occurrence of the "switch statement" BB in the PHI | /// Update the first occurrence of the "switch statement" BB in the PHI | ||||
/// node with the "new" BB. The other occurrences will: | /// node with the "new" BB. The other occurrences will: | ||||
/// | /// | ||||
/// 1) Be updated by subsequent calls to this function. Switch statements may | /// 1) Be updated by subsequent calls to this function. Switch statements may | ||||
/// have more than one outcoming edge into the same BB if they all have the same | /// have more than one outcoming edge into the same BB if they all have the same | ||||
/// value. When the switch statement is converted these incoming edges are now | /// value. When the switch statement is converted these incoming edges are now | ||||
/// coming from multiple BBs. | /// coming from multiple BBs. | ||||
/// 2) Removed if subsequent incoming values now share the same case, i.e., | /// 2) Removed if subsequent incoming values now share the same case, i.e., | ||||
/// multiple outcome edges are condensed into one. This is necessary to keep the | /// multiple outcome edges are condensed into one. This is necessary to keep the | ||||
/// number of phi values equal to the number of branches to SuccBB. | /// number of phi values equal to the number of branches to SuccBB. | ||||
static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, | static void | ||||
unsigned NumMergedCases) { | fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, | ||||
const unsigned NumMergedCases = std::numeric_limits<unsigned>::max()) { | |||||
for (BasicBlock::iterator I = SuccBB->begin(), | for (BasicBlock::iterator I = SuccBB->begin(), | ||||
IE = SuccBB->getFirstNonPHI()->getIterator(); | IE = SuccBB->getFirstNonPHI()->getIterator(); | ||||
I != IE; ++I) { | I != IE; ++I) { | ||||
PHINode *PN = cast<PHINode>(I); | PHINode *PN = cast<PHINode>(I); | ||||
// Only update the first occurrence. | // Only update the first occurrence. | ||||
unsigned Idx = 0, E = PN->getNumIncomingValues(); | unsigned Idx = 0, E = PN->getNumIncomingValues(); | ||||
unsigned LocalNumMergedCases = NumMergedCases; | unsigned LocalNumMergedCases = NumMergedCases; | ||||
Show All 25 Lines | |||||
/// a block emitted by one of the previous calls to switchConvert in the call | /// a block emitted by one of the previous calls to switchConvert in the call | ||||
/// stack. | /// stack. | ||||
BasicBlock * | BasicBlock * | ||||
LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, | LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, | ||||
ConstantInt *UpperBound, Value *Val, | ConstantInt *UpperBound, Value *Val, | ||||
BasicBlock *Predecessor, BasicBlock *OrigBlock, | BasicBlock *Predecessor, BasicBlock *OrigBlock, | ||||
BasicBlock *Default, | BasicBlock *Default, | ||||
const std::vector<IntRange> &UnreachableRanges) { | const std::vector<IntRange> &UnreachableRanges) { | ||||
assert(LowerBound && UpperBound && "Bounds must be initialized"); | |||||
unsigned Size = End - Begin; | unsigned Size = End - Begin; | ||||
if (Size == 1) { | if (Size == 1) { | ||||
// Check if the Case Range is perfectly squeezed in between | // Check if the Case Range is perfectly squeezed in between | ||||
// already checked Upper and Lower bounds. If it is then we can avoid | // already checked Upper and Lower bounds. If it is then we can avoid | ||||
// emitting the code that checks if the value actually falls in the range | // emitting the code that checks if the value actually falls in the range | ||||
// because the bounds already tell us so. | // because the bounds already tell us so. | ||||
if (Begin->Low == LowerBound && Begin->High == UpperBound) { | if (Begin->Low == LowerBound && Begin->High == UpperBound) { | ||||
unsigned NumMergedCases = 0; | unsigned NumMergedCases = 0; | ||||
if (LowerBound && UpperBound) | NumMergedCases = UpperBound->getSExtValue() - LowerBound->getSExtValue(); | ||||
NumMergedCases = | |||||
UpperBound->getSExtValue() - LowerBound->getSExtValue(); | |||||
fixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases); | fixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases); | ||||
return Begin->BB; | return Begin->BB; | ||||
} | } | ||||
return newLeafBlock(*Begin, Val, OrigBlock, Default); | return newLeafBlock(*Begin, Val, LowerBound, UpperBound, OrigBlock, | ||||
Default); | |||||
} | } | ||||
unsigned Mid = Size / 2; | unsigned Mid = Size / 2; | ||||
std::vector<CaseRange> LHS(Begin, Begin + Mid); | std::vector<CaseRange> LHS(Begin, Begin + Mid); | ||||
LLVM_DEBUG(dbgs() << "LHS: " << LHS << "\n"); | LLVM_DEBUG(dbgs() << "LHS: " << LHS << "\n"); | ||||
std::vector<CaseRange> RHS(Begin + Mid, End); | std::vector<CaseRange> RHS(Begin + Mid, End); | ||||
LLVM_DEBUG(dbgs() << "RHS: " << RHS << "\n"); | LLVM_DEBUG(dbgs() << "RHS: " << RHS << "\n"); | ||||
CaseRange &Pivot = *(Begin + Mid); | CaseRange &Pivot = *(Begin + Mid); | ||||
LLVM_DEBUG(dbgs() << "Pivot ==> " << Pivot.Low->getValue() << " -" | LLVM_DEBUG(dbgs() << "Pivot ==> [" << Pivot.Low->getValue() << ", " | ||||
<< Pivot.High->getValue() << "\n"); | << Pivot.High->getValue() << "]\n"); | ||||
// NewLowerBound here should never be the integer minimal value. | // NewLowerBound here should never be the integer minimal value. | ||||
// This is because it is computed from a case range that is never | // This is because it is computed from a case range that is never | ||||
// the smallest, so there is always a case range that has at least | // the smallest, so there is always a case range that has at least | ||||
// a smaller value. | // a smaller value. | ||||
ConstantInt *NewLowerBound = Pivot.Low; | ConstantInt *NewLowerBound = Pivot.Low; | ||||
// Because NewLowerBound is never the smallest representable integer | // Because NewLowerBound is never the smallest representable integer | ||||
// it is safe here to subtract one. | // it is safe here to subtract one. | ||||
ConstantInt *NewUpperBound = ConstantInt::get(NewLowerBound->getContext(), | ConstantInt *NewUpperBound = ConstantInt::get(NewLowerBound->getContext(), | ||||
NewLowerBound->getValue() - 1); | NewLowerBound->getValue() - 1); | ||||
if (!UnreachableRanges.empty()) { | if (!UnreachableRanges.empty()) { | ||||
// Check if the gap between LHS's highest and NewLowerBound is unreachable. | // Check if the gap between LHS's highest and NewLowerBound is unreachable. | ||||
int64_t GapLow = LHS.back().High->getSExtValue() + 1; | int64_t GapLow = LHS.back().High->getSExtValue() + 1; | ||||
int64_t GapHigh = NewLowerBound->getSExtValue() - 1; | int64_t GapHigh = NewLowerBound->getSExtValue() - 1; | ||||
IntRange Gap = { GapLow, GapHigh }; | IntRange Gap = { GapLow, GapHigh }; | ||||
if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges)) | if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges)) | ||||
NewUpperBound = LHS.back().High; | NewUpperBound = LHS.back().High; | ||||
} | } | ||||
LLVM_DEBUG(dbgs() << "LHS Bounds ==> "; if (LowerBound) { | LLVM_DEBUG(dbgs() << "LHS Bounds ==> [" << LowerBound->getSExtValue() << ", " | ||||
dbgs() << LowerBound->getSExtValue(); | << NewUpperBound->getSExtValue() << "]\n" | ||||
} else { dbgs() << "NONE"; } dbgs() << " - " | << "RHS Bounds ==> [" << NewLowerBound->getSExtValue() | ||||
<< NewUpperBound->getSExtValue() << "\n"; | << ", " << UpperBound->getSExtValue() << "]\n"); | ||||
dbgs() << "RHS Bounds ==> "; | |||||
dbgs() << NewLowerBound->getSExtValue() << " - "; if (UpperBound) { | |||||
dbgs() << UpperBound->getSExtValue() << "\n"; | |||||
} else { dbgs() << "NONE\n"; }); | |||||
// Create a new node that checks if the value is < pivot. Go to the | // Create a new node that checks if the value is < pivot. Go to the | ||||
// left branch if it is and right branch if not. | // left branch if it is and right branch if not. | ||||
Function* F = OrigBlock->getParent(); | Function* F = OrigBlock->getParent(); | ||||
BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock"); | BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock"); | ||||
ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, | ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, | ||||
Val, Pivot.Low, "Pivot"); | Val, Pivot.Low, "Pivot"); | ||||
Show All 11 Lines | LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, | ||||
BranchInst::Create(LBranch, RBranch, Comp, NewNode); | BranchInst::Create(LBranch, RBranch, Comp, NewNode); | ||||
return NewNode; | return NewNode; | ||||
} | } | ||||
/// Create a new leaf block for the binary lookup tree. It checks if the | /// Create a new leaf block for the binary lookup tree. It checks if the | ||||
/// switch's value == the case's value. If not, then it jumps to the default | /// switch's value == the case's value. If not, then it jumps to the default | ||||
/// branch. At this point in the tree, the value can't be another valid case | /// branch. At this point in the tree, the value can't be another valid case | ||||
/// value, so the jump to the "default" branch is warranted. | /// value, so the jump to the "default" branch is warranted. | ||||
BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, | BasicBlock *LowerSwitch::newLeafBlock(CaseRange &Leaf, Value *Val, | ||||
ConstantInt *LowerBound, | |||||
ConstantInt *UpperBound, | |||||
BasicBlock* OrigBlock, | BasicBlock *OrigBlock, | ||||
BasicBlock* Default) { | BasicBlock *Default) { | ||||
Function* F = OrigBlock->getParent(); | Function* F = OrigBlock->getParent(); | ||||
BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); | BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); | ||||
F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewLeaf); | F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewLeaf); | ||||
// Emit comparison | // Emit comparison | ||||
ICmpInst* Comp = nullptr; | ICmpInst* Comp = nullptr; | ||||
if (Leaf.Low == Leaf.High) { | if (Leaf.Low == Leaf.High) { | ||||
// Make the seteq instruction... | // Make the seteq instruction... | ||||
Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val, | Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val, | ||||
Leaf.Low, "SwitchLeaf"); | Leaf.Low, "SwitchLeaf"); | ||||
} else { | } else { | ||||
// Make range comparison | // Make range comparison | ||||
if (Leaf.Low->isMinValue(true /*isSigned*/)) { | if (Leaf.Low == LowerBound) { | ||||
// Val >= Min && Val <= Hi --> Val <= Hi | // Val >= Min && Val <= Hi --> Val <= Hi | ||||
Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, | Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, | ||||
"SwitchLeaf"); | "SwitchLeaf"); | ||||
} else if (Leaf.High == UpperBound) { | |||||
// Val <= Max && Val >= Lo --> Val >= Lo | |||||
Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SGE, Val, Leaf.Low, | |||||
"SwitchLeaf"); | |||||
} else if (Leaf.Low->isZero()) { | } else if (Leaf.Low->isZero()) { | ||||
// Val >= 0 && Val <= Hi --> Val <=u Hi | // Val >= 0 && Val <= Hi --> Val <=u Hi | ||||
Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, | Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, | ||||
"SwitchLeaf"); | "SwitchLeaf"); | ||||
} else { | } else { | ||||
// Emit V-Lo <=u Hi-Lo | // Emit V-Lo <=u Hi-Lo | ||||
Constant* NegLo = ConstantExpr::getNeg(Leaf.Low); | Constant* NegLo = ConstantExpr::getNeg(Leaf.Low); | ||||
Instruction* Add = BinaryOperator::CreateAdd(Val, NegLo, | Instruction* Add = BinaryOperator::CreateAdd(Val, NegLo, | ||||
Show All 23 Lines | for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { | ||||
int BlockIdx = PN->getBasicBlockIndex(OrigBlock); | int BlockIdx = PN->getBasicBlockIndex(OrigBlock); | ||||
assert(BlockIdx != -1 && "Switch didn't go to this successor??"); | assert(BlockIdx != -1 && "Switch didn't go to this successor??"); | ||||
PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); | PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); | ||||
} | } | ||||
return NewLeaf; | return NewLeaf; | ||||
} | } | ||||
/// Transform simple list of Cases into list of CaseRange's. | /// Transform simple list of \p SI's cases into list of CaseRange's \p Cases. | ||||
/// \post \p Cases wouldn't contain references to \p SI's default BB. | |||||
/// \returns Number of \p SI's cases that do not reference \p SI's default BB. | |||||
unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { | unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { | ||||
unsigned numCmps = 0; | unsigned NumSimpleCases = 0; | ||||
// Start with "simple" cases | // Start with "simple" cases | ||||
for (auto Case : SI->cases()) | for (auto Case : SI->cases()) { | ||||
if (Case.getCaseSuccessor() == SI->getDefaultDest()) | |||||
continue; | |||||
Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(), | Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(), | ||||
Case.getCaseSuccessor())); | Case.getCaseSuccessor())); | ||||
++NumSimpleCases; | |||||
} | |||||
llvm::sort(Cases, CaseCmp()); | llvm::sort(Cases, CaseCmp()); | ||||
// Merge case into clusters | // Merge case into clusters | ||||
if (Cases.size() >= 2) { | if (Cases.size() >= 2) { | ||||
CaseItr I = Cases.begin(); | CaseItr I = Cases.begin(); | ||||
for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) { | for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) { | ||||
int64_t nextValue = J->Low->getSExtValue(); | int64_t nextValue = J->Low->getSExtValue(); | ||||
Show All 9 Lines | for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) { | ||||
// FIXME: Combine branch weights. | // FIXME: Combine branch weights. | ||||
} else if (++I != J) { | } else if (++I != J) { | ||||
*I = *J; | *I = *J; | ||||
} | } | ||||
} | } | ||||
Cases.erase(std::next(I), Cases.end()); | Cases.erase(std::next(I), Cases.end()); | ||||
} | } | ||||
for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) { | return NumSimpleCases; | ||||
if (I->Low != I->High) | |||||
// A range counts double, since it requires two compares. | |||||
++numCmps; | |||||
} | } | ||||
return numCmps; | static ConstantRange getConstantRangeFromKnownBits(const KnownBits &Known) { | ||||
APInt Lower = Known.One; | |||||
APInt Upper = ~Known.Zero + 1; | |||||
if (Upper == Lower) | |||||
return ConstantRange(Known.getBitWidth(), /*isFullSet=*/true); | |||||
return ConstantRange(Lower, Upper); | |||||
} | } | ||||
/// Replace the specified switch instruction with a sequence of chained if-then | /// Replace the specified switch instruction with a sequence of chained if-then | ||||
/// insts in a balanced binary search. | /// insts in a balanced binary search. | ||||
void LowerSwitch::processSwitchInst(SwitchInst *SI, | void LowerSwitch::processSwitchInst(SwitchInst *SI, | ||||
SmallPtrSetImpl<BasicBlock*> &DeleteList) { | SmallPtrSetImpl<BasicBlock *> &DeleteList, | ||||
BasicBlock *CurBlock = SI->getParent(); | AssumptionCache *AC, LazyValueInfo *LVI) { | ||||
BasicBlock *OrigBlock = CurBlock; | BasicBlock *OrigBlock = SI->getParent(); | ||||
Function *F = CurBlock->getParent(); | Function *F = OrigBlock->getParent(); | ||||
Value *Val = SI->getCondition(); // The value we are switching on... | Value *Val = SI->getCondition(); // The value we are switching on... | ||||
BasicBlock* Default = SI->getDefaultDest(); | BasicBlock* Default = SI->getDefaultDest(); | ||||
// Don't handle unreachable blocks. If there are successors with phis, this | // Don't handle unreachable blocks. If there are successors with phis, this | ||||
// would leave them behind with missing predecessors. | // would leave them behind with missing predecessors. | ||||
if ((CurBlock != &F->getEntryBlock() && pred_empty(CurBlock)) || | if ((OrigBlock != &F->getEntryBlock() && pred_empty(OrigBlock)) || | ||||
CurBlock->getSinglePredecessor() == CurBlock) { | OrigBlock->getSinglePredecessor() == OrigBlock) { | ||||
DeleteList.insert(CurBlock); | DeleteList.insert(OrigBlock); | ||||
return; | return; | ||||
} | } | ||||
// Prepare cases vector. | |||||
CaseVector Cases; | |||||
const unsigned NumSimpleCases = Clusterify(Cases, SI); | |||||
LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() | |||||
<< ". Total non-default cases: " << NumSimpleCases | |||||
<< "\nCase clusters: " << Cases << "\n"); | |||||
// If there is only the default destination, just branch. | // If there is only the default destination, just branch. | ||||
if (!SI->getNumCases()) { | if (Cases.empty()) { | ||||
BranchInst::Create(Default, CurBlock); | BranchInst::Create(Default, OrigBlock); | ||||
// Remove all the references from Default's PHIs to OrigBlock, but one. | |||||
fixPhis(Default, OrigBlock, OrigBlock); | |||||
SI->eraseFromParent(); | SI->eraseFromParent(); | ||||
return; | return; | ||||
} | } | ||||
// Prepare cases vector. | |||||
CaseVector Cases; | |||||
unsigned numCmps = Clusterify(Cases, SI); | |||||
LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() | |||||
<< ". Total compares: " << numCmps << "\n"); | |||||
LLVM_DEBUG(dbgs() << "Cases: " << Cases << "\n"); | |||||
(void)numCmps; | |||||
ConstantInt *LowerBound = nullptr; | ConstantInt *LowerBound = nullptr; | ||||
ConstantInt *UpperBound = nullptr; | ConstantInt *UpperBound = nullptr; | ||||
std::vector<IntRange> UnreachableRanges; | bool DefaultIsUnreachableFromSwitch = false; | ||||
if (isa<UnreachableInst>(Default->getFirstNonPHIOrDbg())) { | if (isa<UnreachableInst>(Default->getFirstNonPHIOrDbg())) { | ||||
// Make the bounds tightly fitted around the case value range, because we | // Make the bounds tightly fitted around the case value range, because we | ||||
// know that the value passed to the switch must be exactly one of the case | // know that the value passed to the switch must be exactly one of the case | ||||
// values. | // values. | ||||
assert(!Cases.empty()); | |||||
LowerBound = Cases.front().Low; | LowerBound = Cases.front().Low; | ||||
UpperBound = Cases.back().High; | UpperBound = Cases.back().High; | ||||
DefaultIsUnreachableFromSwitch = true; | |||||
} else { | |||||
// Constraining the range of the value being switched over helps eliminating | |||||
// unreachable BBs and minimizing the number of `add` instructions | |||||
// newLeafBlock ends up emitting. Running CorrelatedValuePropagation after | |||||
// LowerSwitch isn't as good, and also much more expensive in terms of | |||||
// compile time for the following reasons: | |||||
// 1. it processes many kinds of instructions, not just switches; | |||||
// 2. even if limited to icmp instructions only, it will have to process | |||||
// roughly C icmp's per switch, where C is the number of cases in the | |||||
// switch, while LowerSwitch only needs to call LVI once per switch. | |||||
const DataLayout &DL = F->getParent()->getDataLayout(); | |||||
KnownBits Known = computeKnownBits(Val, DL, /*Depth=*/0, AC, SI); | |||||
ConstantRange KnownBitsRange = getConstantRangeFromKnownBits(Known); | |||||
const ConstantRange LVIRange = LVI->getConstantRange(Val, OrigBlock, SI); | |||||
ConstantRange ValRange = KnownBitsRange.intersectWith(LVIRange); | |||||
// We delegate removal of unreachable non-default cases to other passes. In | |||||
// the unlikely event that some of them survived, we just conservatively | |||||
// maintain the invariant that all the cases lie between the bounds. This | |||||
// may, however, still render the default case effectively unreachable. | |||||
APInt Low = Cases.front().Low->getValue(); | |||||
APInt High = Cases.back().High->getValue(); | |||||
APInt Min = APIntOps::smin(ValRange.getSignedMin(), Low); | |||||
APInt Max = APIntOps::smax(ValRange.getSignedMax(), High); | |||||
LowerBound = ConstantInt::get(SI->getContext(), Min); | |||||
UpperBound = ConstantInt::get(SI->getContext(), Max); | |||||
DefaultIsUnreachableFromSwitch = (Min + (NumSimpleCases - 1) == Max); | |||||
} | |||||
std::vector<IntRange> UnreachableRanges; | |||||
if (DefaultIsUnreachableFromSwitch) { | |||||
DenseMap<BasicBlock *, unsigned> Popularity; | DenseMap<BasicBlock *, unsigned> Popularity; | ||||
unsigned MaxPop = 0; | unsigned MaxPop = 0; | ||||
BasicBlock *PopSucc = nullptr; | BasicBlock *PopSucc = nullptr; | ||||
IntRange R = {std::numeric_limits<int64_t>::min(), | IntRange R = {std::numeric_limits<int64_t>::min(), | ||||
std::numeric_limits<int64_t>::max()}; | std::numeric_limits<int64_t>::max()}; | ||||
UnreachableRanges.push_back(R); | UnreachableRanges.push_back(R); | ||||
for (const auto &I : Cases) { | for (const auto &I : Cases) { | ||||
Show All 30 Lines | for (auto I = UnreachableRanges.begin(), E = UnreachableRanges.end(); | ||||
auto Next = I + 1; | auto Next = I + 1; | ||||
if (Next != E) { | if (Next != E) { | ||||
assert(Next->Low > I->High); | assert(Next->Low > I->High); | ||||
} | } | ||||
} | } | ||||
#endif | #endif | ||||
// As the default block in the switch is unreachable, update the PHI nodes | // As the default block in the switch is unreachable, update the PHI nodes | ||||
// (remove the entry to the default block) to reflect this. | // (remove all of the references to the default block) to reflect this. | ||||
const unsigned NumDefaultEdges = SI->getNumCases() + 1 - NumSimpleCases; | |||||
for (unsigned I = 0; I < NumDefaultEdges; ++I) | |||||
Default->removePredecessor(OrigBlock); | Default->removePredecessor(OrigBlock); | ||||
// Use the most popular block as the new default, reducing the number of | // Use the most popular block as the new default, reducing the number of | ||||
// cases. | // cases. | ||||
assert(MaxPop > 0 && PopSucc); | assert(MaxPop > 0 && PopSucc); | ||||
Default = PopSucc; | Default = PopSucc; | ||||
Cases.erase( | Cases.erase( | ||||
llvm::remove_if( | llvm::remove_if( | ||||
Cases, [PopSucc](const CaseRange &R) { return R.BB == PopSucc; }), | Cases, [PopSucc](const CaseRange &R) { return R.BB == PopSucc; }), | ||||
Cases.end()); | Cases.end()); | ||||
// If there are no cases left, just branch. | // If there are no cases left, just branch. | ||||
if (Cases.empty()) { | if (Cases.empty()) { | ||||
BranchInst::Create(Default, CurBlock); | BranchInst::Create(Default, OrigBlock); | ||||
SI->eraseFromParent(); | SI->eraseFromParent(); | ||||
// As all the cases have been replaced with a single branch, only keep | // As all the cases have been replaced with a single branch, only keep | ||||
// one entry in the PHI nodes. | // one entry in the PHI nodes. | ||||
for (unsigned I = 0 ; I < (MaxPop - 1) ; ++I) | for (unsigned I = 0 ; I < (MaxPop - 1) ; ++I) | ||||
PopSucc->removePredecessor(OrigBlock); | PopSucc->removePredecessor(OrigBlock); | ||||
return; | return; | ||||
} | } | ||||
} | } | ||||
unsigned NrOfDefaults = (SI->getDefaultDest() == Default) ? 1 : 0; | |||||
for (const auto &Case : SI->cases()) | |||||
if (Case.getCaseSuccessor() == Default) | |||||
NrOfDefaults++; | |||||
// Create a new, empty default block so that the new hierarchy of | // Create a new, empty default block so that the new hierarchy of | ||||
// if-then statements go to this and the PHI nodes are happy. | // if-then statements go to this and the PHI nodes are happy. | ||||
BasicBlock *NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); | BasicBlock *NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); | ||||
F->getBasicBlockList().insert(Default->getIterator(), NewDefault); | F->getBasicBlockList().insert(Default->getIterator(), NewDefault); | ||||
BranchInst::Create(Default, NewDefault); | BranchInst::Create(Default, NewDefault); | ||||
BasicBlock *SwitchBlock = | BasicBlock *SwitchBlock = | ||||
switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val, | switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val, | ||||
OrigBlock, OrigBlock, NewDefault, UnreachableRanges); | OrigBlock, OrigBlock, NewDefault, UnreachableRanges); | ||||
// If there are entries in any PHI nodes for the default edge, make sure | // If there are entries in any PHI nodes for the default edge, make sure | ||||
// to update them as well. | // to update them as well. | ||||
fixPhis(Default, OrigBlock, NewDefault, NrOfDefaults); | fixPhis(Default, OrigBlock, NewDefault); | ||||
// Branch to our shiny new if-then stuff... | // Branch to our shiny new if-then stuff... | ||||
BranchInst::Create(SwitchBlock, OrigBlock); | BranchInst::Create(SwitchBlock, OrigBlock); | ||||
// We are now done with the switch instruction, delete it. | // We are now done with the switch instruction, delete it. | ||||
BasicBlock *OldDefault = SI->getDefaultDest(); | BasicBlock *OldDefault = SI->getDefaultDest(); | ||||
CurBlock->getInstList().erase(SI); | OrigBlock->getInstList().erase(SI); | ||||
// If the Default block has no more predecessors just add it to DeleteList. | // If the Default block has no more predecessors just add it to DeleteList. | ||||
if (pred_begin(OldDefault) == pred_end(OldDefault)) | if (pred_begin(OldDefault) == pred_end(OldDefault)) | ||||
DeleteList.insert(OldDefault); | DeleteList.insert(OldDefault); | ||||
} | } |