[LowerSwitch] Keep case bounds as APInt instead of int64_t (PR59316)

Case bounds were read with ConstantInt::getSExtValue(), which cannot
represent case values that do not fit in a signed 64-bit integer (such as
the i128 constants in the new test). Store IntRange bounds as APInt and use
signed APInt comparisons, matching the signed ordering the pass already
uses for its ICMP_SLT pivots.

diff --git a/llvm/lib/Transforms/Utils/LowerSwitch.cpp b/llvm/lib/Transforms/Utils/LowerSwitch.cpp
--- a/llvm/lib/Transforms/Utils/LowerSwitch.cpp
+++ b/llvm/lib/Transforms/Utils/LowerSwitch.cpp
@@ -51,9 +51,9 @@
 
 namespace {
 
-  struct IntRange {
-    int64_t Low, High;
-  };
+struct IntRange {
+  APInt Low, High;
+};
 
 } // end anonymous namespace
 
@@ -66,8 +66,8 @@
   // then check if the Low field is <= R.Low. If so, we
   // have a Range that covers R.
   auto I = llvm::lower_bound(
-      Ranges, R, [](IntRange A, IntRange B) { return A.High < B.High; });
-  return I != Ranges.end() && I->Low <= R.Low;
+      Ranges, R, [](IntRange A, IntRange B) { return A.High.slt(B.High); });
+  return I != Ranges.end() && I->Low.sle(R.Low);
 }
 
 struct CaseRange {
@@ -273,10 +273,10 @@
 
   if (!UnreachableRanges.empty()) {
     // Check if the gap between LHS's highest and NewLowerBound is unreachable.
-    int64_t GapLow = LHS.back().High->getSExtValue() + 1;
-    int64_t GapHigh = NewLowerBound->getSExtValue() - 1;
-    IntRange Gap = { GapLow, GapHigh };
-    if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges))
+    APInt GapLow = LHS.back().High->getValue() + 1;
+    APInt GapHigh = NewLowerBound->getValue() - 1;
+    IntRange Gap = {GapLow, GapHigh};
+    if (GapHigh.sge(GapLow) && IsInRanges(Gap, UnreachableRanges))
       NewUpperBound = LHS.back().High;
   }
 
@@ -287,11 +287,10 @@
 
   // Create a new node that checks if the value is < pivot. Go to the
   // left branch if it is and right branch if not.
-  Function* F = OrigBlock->getParent();
-  BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock");
+  Function *F = OrigBlock->getParent();
+  BasicBlock *NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock");
 
-  ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT,
-                                Val, Pivot.Low, "Pivot");
+  ICmpInst *Comp = new ICmpInst(ICmpInst::ICMP_SLT, Val, Pivot.Low, "Pivot");
 
   BasicBlock *LBranch = SwitchConvert(LHS.begin(), LHS.end(), LowerBound,
                                       NewUpperBound, Val,
@@ -328,14 +327,15 @@
   if (Cases.size() >= 2) {
     CaseItr I = Cases.begin();
     for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) {
-      int64_t nextValue = J->Low->getSExtValue();
-      int64_t currentValue = I->High->getSExtValue();
-      BasicBlock* nextBB = J->BB;
-      BasicBlock* currentBB = I->BB;
+      const APInt &nextValue = J->Low->getValue();
+      const APInt &currentValue = I->High->getValue();
+      BasicBlock *nextBB = J->BB;
+      BasicBlock *currentBB = I->BB;
 
       // If the two neighboring cases go to the same destination, merge them
       // into a single case.
-      assert(nextValue > currentValue && "Cases should be strictly ascending");
+      assert(nextValue.sgt(currentValue) &&
+             "Cases should be strictly ascending");
       if ((nextValue == currentValue + 1) && (currentBB == nextBB)) {
         I->High = J->High;
         // FIXME: Combine branch weights.
@@ -356,8 +356,8 @@
                        AssumptionCache *AC, LazyValueInfo *LVI) {
   BasicBlock *OrigBlock = SI->getParent();
   Function *F = OrigBlock->getParent();
-  Value *Val = SI->getCondition();  // The value we are switching on...
-  BasicBlock* Default = SI->getDefaultDest();
+  Value *Val = SI->getCondition(); // The value we are switching on...
+  BasicBlock *Default = SI->getDefaultDest();
 
   // Don't handle unreachable blocks. If there are successors with phis, this
   // would leave them behind with missing predecessors.
@@ -428,35 +428,41 @@
   std::vector<IntRange> UnreachableRanges;
 
   if (DefaultIsUnreachableFromSwitch) {
     DenseMap<BasicBlock *, unsigned> Popularity;
     unsigned MaxPop = 0;
     BasicBlock *PopSucc = nullptr;
 
-    IntRange R = {std::numeric_limits<int64_t>::min(),
-                  std::numeric_limits<int64_t>::max()};
+    unsigned BitWidth = Cases.front().Low->getBitWidth();
+    APInt SignedMax = APInt::getSignedMaxValue(BitWidth);
+    IntRange R = {APInt::getSignedMinValue(BitWidth), SignedMax};
     UnreachableRanges.push_back(R);
     for (const auto &I : Cases) {
-      int64_t Low = I.Low->getSExtValue();
-      int64_t High = I.High->getSExtValue();
+      const APInt &Low = I.Low->getValue();
+      const APInt &High = I.High->getValue();
 
       IntRange &LastRange = UnreachableRanges.back();
-      if (LastRange.Low == Low) {
+      if (LastRange.Low.eq(Low)) {
         // There is nothing left of the previous range.
         UnreachableRanges.pop_back();
       } else {
         // Terminate the previous range.
-        assert(Low > LastRange.Low);
+        assert(Low.sgt(LastRange.Low));
         LastRange.High = Low - 1;
       }
-      if (High != std::numeric_limits<int64_t>::max()) {
-        IntRange R = { High + 1, std::numeric_limits<int64_t>::max() };
+      if (High.ne(SignedMax)) {
+        IntRange R = {High + 1, SignedMax};
         UnreachableRanges.push_back(R);
       }
 
       // Count popularity.
-      int64_t N = High - Low + 1;
+      // High >= Low as signed values, so High - Low read as unsigned is the
+      // exact distance even when the signed subtraction would overflow. Every
+      // value in a clustered range is an explicit case of the switch, so the
+      // count fits in 64 bits whatever the width of the condition.
+      assert(High.sge(Low) && "Case ranges should not be empty");
+      uint64_t N = (High - Low).getZExtValue() + 1;
       unsigned &Pop = Popularity[I.BB];
       if ((Pop += N) > MaxPop) {
         MaxPop = Pop;
         PopSucc = I.BB;
       }
@@ -465,10 +471,10 @@
     /* UnreachableRanges should be sorted and the ranges non-adjacent. */
     for (auto I = UnreachableRanges.begin(), E = UnreachableRanges.end();
          I != E; ++I) {
-      assert(I->Low <= I->High);
+      assert(I->Low.sle(I->High));
       auto Next = I + 1;
       if (Next != E) {
-        assert(Next->Low > I->High);
+        assert(Next->Low.sgt(I->High));
       }
     }
 #endif
diff --git a/llvm/test/Transforms/LowerSwitch/pr59316.ll b/llvm/test/Transforms/LowerSwitch/pr59316.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LowerSwitch/pr59316.ll
@@ -0,0 +1,39 @@
+; RUN: opt < %s -passes=lowerswitch -disable-output
+
+; Case values that do not fit in a signed 64-bit integer must not go through getSExtValue().
+define i64 @f(i1 %0, i128 %1, ptr %2) {
+BB:
+  switch i128 %1, label %BB1 [
+    i128 627, label %SW_C
+    i128 16201310291018008447, label %SW_C2
+  ]
+
+BB1:                                              ; preds = %SW_C, %BB
+  unreachable
+
+SW_C:                                             ; preds = %SW_C, %BB
+  br i1 %0, label %BB1, label %SW_C
+
+SW_C2:                                            ; preds = %BB
+  ret i64 0
+}
+
+; Every i1 value has an explicit case and the default is unreachable, so the
+; most popular successor becomes the new default. The per-successor case count
+; must not be computed in the 1-bit type of the condition, where it wraps.
+define i32 @i1_all_values(i1 %c) {
+entry:
+  switch i1 %c, label %def [
+    i1 false, label %a
+    i1 true, label %b
+  ]
+
+def:
+  unreachable
+
+a:
+  ret i32 0
+
+b:
+  ret i32 1
+}