Index: llvm/trunk/lib/Transforms/Utils/LowerSwitch.cpp =================================================================== --- llvm/trunk/lib/Transforms/Utils/LowerSwitch.cpp +++ llvm/trunk/lib/Transforms/Utils/LowerSwitch.cpp @@ -32,6 +32,23 @@ #define DEBUG_TYPE "lower-switch" namespace { + struct IntRange { + int64_t Low, High; + }; + // Return true iff R is covered by Ranges. + static bool IsInRanges(const IntRange &R, + const std::vector<IntRange> &Ranges) { + // Note: Ranges must be sorted, non-overlapping and non-adjacent. + + // Find the first range whose High field is >= R.High, + // then check if the Low field is <= R.Low. If so, we + // have a Range that covers R. + auto I = std::lower_bound( + Ranges.begin(), Ranges.end(), R, + [](const IntRange &A, const IntRange &B) { return A.High < B.High; }); + return I != Ranges.end() && I->Low <= R.Low; + } + /// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch /// instructions. class LowerSwitch : public FunctionPass { @@ -68,7 +85,8 @@ BasicBlock *switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, ConstantInt *UpperBound, Value *Val, BasicBlock *Predecessor, - BasicBlock *OrigBlock, BasicBlock *Default); + BasicBlock *OrigBlock, BasicBlock *Default, + const std::vector<IntRange> &UnreachableRanges); BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, BasicBlock *OrigBlock, BasicBlock *Default); unsigned Clusterify(CaseVector &Cases, SwitchInst *SI); @@ -172,12 +190,12 @@ // LowerBound and UpperBound are used to keep track of the bounds for Val // that have already been checked by a block emitted by one of the previous // calls to switchConvert in the call stack.
-BasicBlock *LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, - ConstantInt *LowerBound, - ConstantInt *UpperBound, Value *Val, - BasicBlock *Predecessor, - BasicBlock *OrigBlock, - BasicBlock *Default) { +BasicBlock * +LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, + ConstantInt *UpperBound, Value *Val, + BasicBlock *Predecessor, BasicBlock *OrigBlock, + BasicBlock *Default, + const std::vector<IntRange> &UnreachableRanges) { unsigned Size = End - Begin; if (Size == 1) { @@ -212,19 +230,19 @@ // the smallest, so there is always a case range that has at least // a smaller value. ConstantInt *NewLowerBound = cast<ConstantInt>(Pivot.Low); - ConstantInt *NewUpperBound; - // If we don't have a Default block then it means that we can never - // have a value outside of a case range, so set the UpperBound to the highest - // value in the LHS part of the case ranges. - if (Default != nullptr) { - // Because NewLowerBound is never the smallest representable integer - // it is safe here to subtract one. - NewUpperBound = ConstantInt::get(NewLowerBound->getContext(), - NewLowerBound->getValue() - 1); - } else { - CaseItr LastLHS = LHS.begin() + LHS.size() - 1; - NewUpperBound = cast<ConstantInt>(LastLHS->High); + // Because NewLowerBound is never the smallest representable integer + // it is safe here to subtract one. + ConstantInt *NewUpperBound = ConstantInt::get(NewLowerBound->getContext(), + NewLowerBound->getValue() - 1); + + if (!UnreachableRanges.empty()) { + // Check if the gap between LHS's highest and NewLowerBound is unreachable.
+ int64_t GapLow = cast<ConstantInt>(LHS.back().High)->getSExtValue() + 1; + int64_t GapHigh = NewLowerBound->getSExtValue() - 1; + IntRange Gap = { GapLow, GapHigh }; + if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges)) + NewUpperBound = cast<ConstantInt>(LHS.back().High); } DEBUG(dbgs() << "LHS Bounds ==> "; @@ -252,10 +270,10 @@ BasicBlock *LBranch = switchConvert(LHS.begin(), LHS.end(), LowerBound, NewUpperBound, Val, NewNode, OrigBlock, - Default); + Default, UnreachableRanges); BasicBlock *RBranch = switchConvert(RHS.begin(), RHS.end(), NewLowerBound, UpperBound, Val, NewNode, OrigBlock, - Default); + Default, UnreachableRanges); Function::iterator FI = OrigBlock; F->getBasicBlockList().insert(++FI, NewNode); @@ -380,26 +398,102 @@ Value *Val = SI->getCondition(); // The value we are switching on... BasicBlock* Default = SI->getDefaultDest(); - // If there is only the default destination, don't bother with the code below. + // If there is only the default destination, just branch. if (!SI->getNumCases()) { - BranchInst::Create(SI->getDefaultDest(), CurBlock); - CurBlock->getInstList().erase(SI); + BranchInst::Create(Default, CurBlock); + SI->eraseFromParent(); return; } - const bool DefaultIsUnreachable = - Default->size() == 1 && isa<UnreachableInst>(Default->getTerminator()); + // Prepare cases vector. + CaseVector Cases; + unsigned numCmps = Clusterify(Cases, SI); + DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() + << ". Total compares: " << numCmps << "\n"); + DEBUG(dbgs() << "Cases: " << Cases << "\n"); + (void)numCmps; + + ConstantInt *LowerBound = nullptr; + ConstantInt *UpperBound = nullptr; + std::vector<IntRange> UnreachableRanges; + + if (isa<UnreachableInst>(Default->getFirstNonPHIOrDbg())) { + // Make the bounds tightly fitted around the case value range, because we + // know that the value passed to the switch must be exactly one of the case + // values.
+ assert(!Cases.empty()); + LowerBound = cast<ConstantInt>(Cases.front().Low); + UpperBound = cast<ConstantInt>(Cases.back().High); + + DenseMap<BasicBlock *, unsigned> Popularity; + unsigned MaxPop = 0; + BasicBlock *PopSucc = nullptr; + + IntRange R = { INT64_MIN, INT64_MAX }; + UnreachableRanges.push_back(R); + for (const auto &I : Cases) { + int64_t Low = cast<ConstantInt>(I.Low)->getSExtValue(); + int64_t High = cast<ConstantInt>(I.High)->getSExtValue(); + + IntRange &LastRange = UnreachableRanges.back(); + if (LastRange.Low == Low) { + // There is nothing left of the previous range. + UnreachableRanges.pop_back(); + } else { + // Terminate the previous range. + assert(Low > LastRange.Low); + LastRange.High = Low - 1; + } + if (High != INT64_MAX) { + IntRange R = { High + 1, INT64_MAX }; + UnreachableRanges.push_back(R); + } + + // Count popularity. + int64_t N = High - Low + 1; + unsigned &Pop = Popularity[I.BB]; + if ((Pop += N) > MaxPop) { + MaxPop = Pop; + PopSucc = I.BB; + } + } +#ifndef NDEBUG + // UnreachableRanges should be sorted and the ranges non-adjacent. + for (auto I = UnreachableRanges.begin(), E = UnreachableRanges.end(); + I != E; ++I) { + assert(I->Low <= I->High); + auto Next = I + 1; + if (Next != E) { + assert(Next->Low > I->High + 1); + } + } +#endif + + // Use the most popular block as the new default, reducing the number of + // cases. + assert(MaxPop > 0 && PopSucc); + Default = PopSucc; + for (CaseItr I = Cases.begin(); I != Cases.end();) { + if (I->BB == PopSucc) + I = Cases.erase(I); + else + ++I; + } + + // If there are no cases left, just branch. + if (Cases.empty()) { + BranchInst::Create(Default, CurBlock); + SI->eraseFromParent(); + return; + } + } + // Create a new, empty default block so that the new hierarchy of // if-then statements go to this and the PHI nodes are happy. - // if the default block is set as an unreachable we avoid creating one - // because will never be a valid target.
- BasicBlock *NewDefault = nullptr; - if (!DefaultIsUnreachable) { - NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); - F->getBasicBlockList().insert(Default, NewDefault); + BasicBlock *NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); + F->getBasicBlockList().insert(Default, NewDefault); + BranchInst::Create(Default, NewDefault); - BranchInst::Create(Default, NewDefault); - } // If there is an entry in any PHI nodes for the default edge, make sure // to update them as well. for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) { @@ -409,40 +503,18 @@ PN->setIncomingBlock((unsigned)BlockIdx, NewDefault); } - // Prepare cases vector. - CaseVector Cases; - unsigned numCmps = Clusterify(Cases, SI); - - DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() - << ". Total compares: " << numCmps << "\n"); - DEBUG(dbgs() << "Cases: " << Cases << "\n"); - (void)numCmps; - - ConstantInt *UpperBound = nullptr; - ConstantInt *LowerBound = nullptr; - - // Optimize the condition where Default is an unreachable block. In this case - // we can make the bounds tightly fitted around the case value ranges, - // because we know that the value passed to the switch should always be - // exactly one of the case values. - if (DefaultIsUnreachable) { - CaseItr LastCase = Cases.begin() + Cases.size() - 1; - UpperBound = cast<ConstantInt>(LastCase->High); - LowerBound = cast<ConstantInt>(Cases.begin()->Low); - } BasicBlock *SwitchBlock = switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val, - OrigBlock, OrigBlock, NewDefault); + OrigBlock, OrigBlock, NewDefault, UnreachableRanges); // Branch to our shiny new if-then stuff... BranchInst::Create(SwitchBlock, OrigBlock); // We are now done with the switch instruction, delete it.
+ BasicBlock *OldDefault = SI->getDefaultDest(); CurBlock->getInstList().erase(SI); - pred_iterator PI = pred_begin(Default), E = pred_end(Default); - // If the Default block has no more predecessors just remove it - if (PI == E) { - DeleteDeadBlock(Default); - } + // If the Default block has no more predecessors just remove it. + if (pred_begin(OldDefault) == pred_end(OldDefault)) + DeleteDeadBlock(OldDefault); } Index: llvm/trunk/test/Transforms/LowerSwitch/2014-06-11-SwitchDefaultUnreachableOpt.ll =================================================================== --- llvm/trunk/test/Transforms/LowerSwitch/2014-06-11-SwitchDefaultUnreachableOpt.ll +++ llvm/trunk/test/Transforms/LowerSwitch/2014-06-11-SwitchDefaultUnreachableOpt.ll @@ -1,5 +1,8 @@ ; RUN: opt < %s -lowerswitch -S | FileCheck %s -; CHECK-NOT: {{.*}}icmp eq{{.*}} +; +; The switch is lowered with a single icmp. +; CHECK: icmp +; CHECK-NOT: icmp ; ;int foo(int a) { ; @@ -14,7 +17,7 @@ ; ;} -define i32 @foo(i32 %a) nounwind ssp uwtable { +define i32 @foo(i32 %a) { %1 = alloca i32, align 4 %2 = alloca i32, align 4 store i32 %a, i32* %2, align 4 Index: llvm/trunk/test/Transforms/LowerSwitch/fold-popular-case-to-unreachable-default.ll =================================================================== --- llvm/trunk/test/Transforms/LowerSwitch/fold-popular-case-to-unreachable-default.ll +++ llvm/trunk/test/Transforms/LowerSwitch/fold-popular-case-to-unreachable-default.ll @@ -0,0 +1,110 @@ +; RUN: opt %s -lowerswitch -S | FileCheck %s + +define void @foo(i32 %x, i32* %p) { +; Cases 2 and 4 are removed and become the new default case. +; It is now enough to use two icmps to lower the switch.
+; +; CHECK-LABEL: @foo +; CHECK: icmp slt i32 %x, 5 +; CHECK: icmp eq i32 %x, 1 +; CHECK-NOT: icmp +; +entry: + switch i32 %x, label %default [ + i32 1, label %bb0 + i32 2, label %popular + i32 4, label %popular + i32 5, label %bb1 + ] +bb0: + store i32 0, i32* %p + br label %exit +bb1: + store i32 1, i32* %p + br label %exit +popular: + store i32 2, i32* %p + br label %exit +exit: + ret void +default: + unreachable +} + +define void @unreachable_gap(i64 %x, i32* %p) { +; Cases 6 and INT64_MAX become the new default, but we still exploit the fact +; that 3-4 is unreachable, so four icmps are enough. + +; CHECK-LABEL: @unreachable_gap +; CHECK: icmp slt i64 %x, 2 +; CHECK: icmp slt i64 %x, 5 +; CHECK: icmp eq i64 %x, 5 +; CHECK: icmp slt i64 %x, 1 +; CHECK-NOT: icmp + +entry: + switch i64 %x, label %default [ + i64 -9223372036854775808, label %bb0 + i64 1, label %bb1 + i64 2, label %bb2 + i64 5, label %bb3 + i64 6, label %bb4 + i64 9223372036854775807, label %bb4 + ] +bb0: + store i32 0, i32* %p + br label %exit +bb1: + store i32 1, i32* %p + br label %exit +bb2: + store i32 2, i32* %p + br label %exit +bb3: + store i32 3, i32* %p + br label %exit +bb4: + store i32 4, i32* %p + br label %exit +exit: + ret void +default: + unreachable +} + + + +define void @nocases(i32 %x, i32* %p) { +; Don't fall over when there are no cases. +; +; CHECK-LABEL: @nocases +; CHECK-LABEL: entry +; CHECK-NEXT: br label %default +; +entry: + switch i32 %x, label %default [ + ] +default: + unreachable +} + +define void @nocasesleft(i32 %x, i32* %p) { +; Cases 2 and 4 are removed and we are left with no cases. +; +; CHECK-LABEL: @nocasesleft +; CHECK-LABEL: entry +; CHECK-NEXT: br label %popular +; +entry: + switch i32 %x, label %default [ + i32 2, label %popular + i32 4, label %popular + ] +popular: + store i32 2, i32* %p + br label %exit +exit: + ret void +default: + unreachable +}