Index: include/llvm/ADT/iterator.h =================================================================== --- include/llvm/ADT/iterator.h +++ include/llvm/ADT/iterator.h @@ -165,9 +165,15 @@ return !static_cast(this)->operator<(RHS); } + PointerT operator->() { return &static_cast(this)->operator*(); } PointerT operator->() const { return &static_cast(this)->operator*(); } + ReferenceProxy operator[](DifferenceTypeT n) { + static_assert(IsRandomAccess, + "Subscripting is only defined for random access iterators."); + return ReferenceProxy(static_cast(this)->operator+(n)); + } ReferenceProxy operator[](DifferenceTypeT n) const { static_assert(IsRandomAccess, "Subscripting is only defined for random access iterators."); Index: include/llvm/Analysis/CFGPrinter.h =================================================================== --- include/llvm/Analysis/CFGPrinter.h +++ include/llvm/Analysis/CFGPrinter.h @@ -140,8 +140,7 @@ std::string Str; raw_string_ostream OS(Str); - SwitchInst::ConstCaseIt Case = - SwitchInst::ConstCaseIt::fromSuccessorIndex(SI, SuccNo); + auto Case = *SwitchInst::ConstCaseIt::fromSuccessorIndex(SI, SuccNo); OS << Case.getCaseValue()->getValue(); return OS.str(); } Index: include/llvm/IR/Instructions.h =================================================================== --- include/llvm/IR/Instructions.h +++ include/llvm/IR/Instructions.h @@ -3093,49 +3093,39 @@ // -2 static const unsigned DefaultPseudoIndex = static_cast(~0L-1); - template - class CaseIteratorT - : public iterator_facade_base< - CaseIteratorT, - std::random_access_iterator_tag, - CaseIteratorT> { + template class CaseIteratorT; + + /// A handle to a particular switch case. It exposes a convenient interface + /// to both the case value and the successor block. + /// + /// We define this as a template and instantiate it to form both a const and + /// non-const handle. + template + class CaseHandleT { + // Directly befriend both const and non-const iterators. + friend class SwitchInst::CaseIteratorT< + CaseHandleT>; + protected: + // Expose the switch type we're parameterized with to the iterator. + typedef SwitchInstT SwitchInstType; + SwitchInstT *SI; ptrdiff_t Index; - public: - typedef CaseIteratorT Self; - - /// Default constructed iterator is in an invalid state until assigned to - /// a case for a particular switch. - CaseIteratorT() : SI(nullptr) {} - - /// Initializes case iterator for given SwitchInst and for given - /// case number. - CaseIteratorT(SwitchInstT *SI, unsigned CaseNum) { - this->SI = SI; - Index = CaseNum; - } - - /// Initializes case iterator for given SwitchInst and for given - /// TerminatorInst's successor index. - static Self fromSuccessorIndex(SwitchInstT *SI, unsigned SuccessorIndex) { - assert(SuccessorIndex < SI->getNumSuccessors() && - "Successor index # out of range!"); - return SuccessorIndex != 0 ? - Self(SI, SuccessorIndex - 1) : - Self(SI, DefaultPseudoIndex); - } + CaseHandleT() = default; + CaseHandleT(SwitchInstT *SI, ptrdiff_t Index) : SI(SI), Index(Index) {} + public: /// Resolves case value for current case. - ConstantIntT *getCaseValue() { + ConstantIntT *getCaseValue() const { assert((unsigned)Index < SI->getNumCases() && "Index out the number of cases."); return reinterpret_cast(SI->getOperand(2 + Index * 2)); } /// Resolves successor for current case. - BasicBlockT *getCaseSuccessor() { + BasicBlockT *getCaseSuccessor() const { assert(((unsigned)Index < SI->getNumCases() || (unsigned)Index == DefaultPseudoIndex) && "Index out the number of cases."); @@ -3153,43 +3143,20 @@ return (unsigned)Index != DefaultPseudoIndex ? Index + 1 : 0; } - Self &operator+=(ptrdiff_t N) { - // Check index correctness after addition. - // Note: Index == getNumCases() means end(). - assert(Index + N >= 0 && (unsigned)(Index + N) <= SI->getNumCases() && - "Index out the number of cases."); - Index += N; - return *this; - } - Self &operator-=(ptrdiff_t N) { - // Check index correctness after subtraction. - // Note: Index == getNumCases() means end(). - assert(Index - N >= 0 && (unsigned)(Index - N) <= SI->getNumCases() && - "Index out the number of cases."); - Index -= N; - return *this; - } - bool operator==(const Self& RHS) const { + bool operator==(const CaseHandleT &RHS) const { assert(SI == RHS.SI && "Incompatible operators."); return Index == RHS.Index; } - bool operator<(const Self& RHS) const { - assert(SI == RHS.SI && "Incompatible operators."); - return Index < RHS.Index; - } - Self &operator*() { return *this; } - const Self &operator*() const { return *this; } }; - typedef CaseIteratorT - ConstCaseIt; + typedef CaseHandleT + ConstCaseHandle; - class CaseIt : public CaseIteratorT { - typedef CaseIteratorT ParentTy; + class CaseHandle : public CaseHandleT { + friend class SwitchInst::CaseIteratorT; public: - CaseIt(const ParentTy &Src) : ParentTy(Src) {} - CaseIt(SwitchInst *SI, unsigned CaseNum) : ParentTy(SI, CaseNum) {} + CaseHandle(SwitchInst *SI, ptrdiff_t Index) : CaseHandleT(SI, Index) {} /// Sets the new value for current case. void setValue(ConstantInt *V) { @@ -3204,6 +3171,76 @@ } }; + template + class CaseIteratorT + : public iterator_facade_base, + std::random_access_iterator_tag, + CaseHandleT> { + typedef typename CaseHandleT::SwitchInstType SwitchInstT; + + protected: + CaseHandleT Case; + + public: + typedef CaseIteratorT Self; + + /// Default constructed iterator is in an invalid state until assigned to + /// a case for a particular switch. + CaseIteratorT() = default; + + /// Initializes case iterator for given SwitchInst and for given + /// case number. + CaseIteratorT(SwitchInstT *SI, unsigned CaseNum) : Case(SI, CaseNum) {} + + /// Initializes case iterator for given SwitchInst and for given + /// TerminatorInst's successor index. + static Self fromSuccessorIndex(SwitchInstT *SI, unsigned SuccessorIndex) { + assert(SuccessorIndex < SI->getNumSuccessors() && + "Successor index # out of range!"); + return SuccessorIndex != 0 ? Self(SI, SuccessorIndex - 1) + : Self(SI, DefaultPseudoIndex); + } + + /// Support converting to the const variant. This will be a no-op for const + /// variant. + operator CaseIteratorT() const { + return CaseIteratorT(Case.SI, Case.Index); + } + + Self &operator+=(ptrdiff_t N) { + // Check index correctness after addition. + // Note: Index == getNumCases() means end(). + assert(Case.Index + N >= 0 && + (unsigned)(Case.Index + N) <= Case.SI->getNumCases() && + "Case.Index out the number of cases."); + Case.Index += N; + return *this; + } + Self &operator-=(ptrdiff_t N) { + // Check index correctness after subtraction. + // Note: Case.Index == getNumCases() means end(). + assert(Case.Index - N >= 0 && + (unsigned)(Case.Index - N) <= Case.SI->getNumCases() && + "Case.Index out the number of cases."); + Case.Index -= N; + return *this; + } + ptrdiff_t operator-(const Self &RHS) const { + assert(Case.SI == RHS.Case.SI && "Incompatible operators."); + return Case.Index - RHS.Case.Index; + } + bool operator==(const Self &RHS) const { return Case == RHS.Case; } + bool operator<(const Self &RHS) const { + assert(Case.SI == RHS.Case.SI && "Incompatible operators."); + return Case.Index < RHS.Case.Index; + } + CaseHandleT &operator*() { return Case; } + const CaseHandleT &operator*() const { return Case; } + }; + + typedef CaseIteratorT CaseIt; + typedef CaseIteratorT ConstCaseIt; + static SwitchInst *Create(Value *Value, BasicBlock *Default, unsigned NumCases, Instruction *InsertBefore = nullptr) { @@ -3287,30 +3324,33 @@ /// default case iterator to indicate that it is handled by the default /// handler. CaseIt findCaseValue(const ConstantInt *C) { - for (CaseIt i = case_begin(), e = case_end(); i != e; ++i) - if (i.getCaseValue() == C) - return i; + for (CaseIt I = case_begin(), E = case_end(); I != E; ++I) + if (I->getCaseValue() == C) + return I; return case_default(); } ConstCaseIt findCaseValue(const ConstantInt *C) const { - for (ConstCaseIt i = case_begin(), e = case_end(); i != e; ++i) - if (i.getCaseValue() == C) - return i; + for (ConstCaseIt I = case_begin(), E = case_end(); I != E; ++I) + if (I->getCaseValue() == C) + return I; return case_default(); } /// Finds the unique case value for a given successor. Returns null if the /// successor is not found, not unique, or is the default case. ConstantInt *findCaseDest(BasicBlock *BB) { - if (BB == getDefaultDest()) return nullptr; + if (BB == getDefaultDest()) + return nullptr; ConstantInt *CI = nullptr; - for (CaseIt i = case_begin(), e = case_end(); i != e; ++i) { - if (i.getCaseSuccessor() == BB) { - if (CI) return nullptr; // Multiple cases lead to BB. - else CI = i.getCaseValue(); + for (auto Case : cases()) + if (Case.getCaseSuccessor() == BB) { + if (CI) + return nullptr; // Multiple cases lead to BB. + else + CI = Case.getCaseValue(); } - } + return CI; } @@ -3327,7 +3367,7 @@ /// This action invalidates iterators for all cases following the one removed, /// including the case_end() iterator. It returns an iterator for the next /// case. - CaseIt removeCase(CaseIt i); + CaseIt removeCase(CaseIt I); unsigned getNumSuccessors() const { return getNumOperands()/2; } BasicBlock *getSuccessor(unsigned idx) const { Index: lib/Analysis/InlineCost.cpp =================================================================== --- lib/Analysis/InlineCost.cpp +++ lib/Analysis/InlineCost.cpp @@ -1014,8 +1014,8 @@ // does not (yet) fire. SmallPtrSet SuccessorBlocks; SuccessorBlocks.insert(SI.getDefaultDest()); - for (auto I = SI.case_begin(), E = SI.case_end(); I != E; ++I) - SuccessorBlocks.insert(I.getCaseSuccessor()); + for (auto Case : SI.cases()) + SuccessorBlocks.insert(Case.getCaseSuccessor()); // Add cost corresponding to the number of distinct destinations. The first // we model as free because of fallthrough. Cost += (SuccessorBlocks.size() - 1) * InlineConstants::InstrCost; @@ -1379,7 +1379,7 @@ Value *Cond = SI->getCondition(); if (ConstantInt *SimpleCond = dyn_cast_or_null(SimplifiedValues.lookup(Cond))) { - BBWorklist.insert(SI->findCaseValue(SimpleCond).getCaseSuccessor()); + BBWorklist.insert(SI->findCaseValue(SimpleCond)->getCaseSuccessor()); continue; } } Index: lib/Analysis/LazyValueInfo.cpp =================================================================== --- lib/Analysis/LazyValueInfo.cpp +++ lib/Analysis/LazyValueInfo.cpp @@ -1430,14 +1430,14 @@ unsigned BitWidth = Val->getType()->getIntegerBitWidth(); ConstantRange EdgesVals(BitWidth, DefaultCase/*isFullSet*/); - for (SwitchInst::CaseIt i : SI->cases()) { - ConstantRange EdgeVal(i.getCaseValue()->getValue()); + for (auto Case : SI->cases()) { + ConstantRange EdgeVal(Case.getCaseValue()->getValue()); if (DefaultCase) { // It is possible that the default destination is the destination of // some cases. There is no need to perform difference for those cases. - if (i.getCaseSuccessor() != BBTo) + if (Case.getCaseSuccessor() != BBTo) EdgesVals = EdgesVals.difference(EdgeVal); - } else if (i.getCaseSuccessor() == BBTo) + } else if (Case.getCaseSuccessor() == BBTo) EdgesVals = EdgesVals.unionWith(EdgeVal); } Result = LVILatticeVal::getRange(std::move(EdgesVals)); Index: lib/Analysis/SparsePropagation.cpp =================================================================== --- lib/Analysis/SparsePropagation.cpp +++ lib/Analysis/SparsePropagation.cpp @@ -195,7 +195,7 @@ Succs.assign(TI.getNumSuccessors(), true); return; } - SwitchInst::CaseIt Case = SI.findCaseValue(cast(C)); + SwitchInst::CaseHandle Case = *SI.findCaseValue(cast(C)); Succs[Case.getSuccessorIndex()] = true; } Index: lib/Bitcode/Writer/BitcodeWriter.cpp =================================================================== --- lib/Bitcode/Writer/BitcodeWriter.cpp +++ lib/Bitcode/Writer/BitcodeWriter.cpp @@ -2578,7 +2578,7 @@ Vals.push_back(VE.getTypeID(SI.getCondition()->getType())); pushValue(SI.getCondition(), InstID, Vals); Vals.push_back(VE.getValueID(SI.getDefaultDest())); - for (SwitchInst::ConstCaseIt Case : SI.cases()) { + for (auto Case : SI.cases()) { Vals.push_back(VE.getValueID(Case.getCaseValue())); Vals.push_back(VE.getValueID(Case.getCaseSuccessor())); } Index: lib/CodeGen/CodeGenPrepare.cpp =================================================================== --- lib/CodeGen/CodeGenPrepare.cpp +++ lib/CodeGen/CodeGenPrepare.cpp @@ -5457,7 +5457,7 @@ auto *ExtInst = CastInst::Create(ExtType, Cond, NewType); ExtInst->insertBefore(SI); SI->setCondition(ExtInst); - for (SwitchInst::CaseIt Case : SI->cases()) { + for (auto Case : SI->cases()) { APInt NarrowConst = Case.getCaseValue()->getValue(); APInt WideConst = (ExtType == Instruction::ZExt) ? NarrowConst.zext(RegWidth) : NarrowConst.sext(RegWidth); Index: lib/ExecutionEngine/Interpreter/Execution.cpp =================================================================== --- lib/ExecutionEngine/Interpreter/Execution.cpp +++ lib/ExecutionEngine/Interpreter/Execution.cpp @@ -899,10 +899,10 @@ // Check to see if any of the cases match... BasicBlock *Dest = nullptr; - for (SwitchInst::CaseIt i = I.case_begin(), e = I.case_end(); i != e; ++i) { - GenericValue CaseVal = getOperandValue(i.getCaseValue(), SF); + for (auto Case : I.cases()) { + GenericValue CaseVal = getOperandValue(Case.getCaseValue(), SF); if (executeICMP_EQ(CondVal, CaseVal, ElTy).IntVal != 0) { - Dest = cast(i.getCaseSuccessor()); + Dest = cast(Case.getCaseSuccessor()); break; } } Index: lib/IR/AsmWriter.cpp =================================================================== --- lib/IR/AsmWriter.cpp +++ lib/IR/AsmWriter.cpp @@ -2909,12 +2909,11 @@ Out << ", "; writeOperand(SI.getDefaultDest(), true); Out << " ["; - for (SwitchInst::ConstCaseIt i = SI.case_begin(), e = SI.case_end(); - i != e; ++i) { + for (auto Case : SI.cases()) { Out << "\n "; - writeOperand(i.getCaseValue(), true); + writeOperand(Case.getCaseValue(), true); Out << ", "; - writeOperand(i.getCaseSuccessor(), true); + writeOperand(Case.getCaseSuccessor(), true); } Out << "\n ]"; } else if (isa(I)) { Index: lib/IR/Instructions.cpp =================================================================== --- lib/IR/Instructions.cpp +++ lib/IR/Instructions.cpp @@ -3655,16 +3655,16 @@ // Initialize some new operands. assert(OpNo+1 < ReservedSpace && "Growing didn't work!"); setNumHungOffUseOperands(OpNo+2); - CaseIt Case(this, NewCaseIdx); + CaseHandle Case(this, NewCaseIdx); Case.setValue(OnVal); Case.setSuccessor(Dest); } /// removeCase - This method removes the specified case and its successor /// from the switch instruction. -SwitchInst::CaseIt SwitchInst::removeCase(CaseIt i) { - unsigned idx = i.getCaseIndex(); - +SwitchInst::CaseIt SwitchInst::removeCase(CaseIt I) { + unsigned idx = I->getCaseIndex(); + assert(2 + idx*2 < getNumOperands() && "Case index out of range!!!"); unsigned NumOps = getNumOperands(); Index: lib/Transforms/Coroutines/CoroSplit.cpp =================================================================== --- lib/Transforms/Coroutines/CoroSplit.cpp +++ lib/Transforms/Coroutines/CoroSplit.cpp @@ -185,9 +185,9 @@ coro::Shape &Shape, SwitchInst *Switch, bool IsDestroy) { assert(Shape.HasFinalSuspend); - auto FinalCase = --Switch->case_end(); - BasicBlock *ResumeBB = FinalCase.getCaseSuccessor(); - Switch->removeCase(FinalCase); + auto FinalCaseIt = std::prev(Switch->case_end()); + BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor(); + Switch->removeCase(FinalCaseIt); if (IsDestroy) { BasicBlock *OldSwitchBB = Switch->getParent(); auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch"); Index: lib/Transforms/InstCombine/InstructionCombining.cpp =================================================================== --- lib/Transforms/InstCombine/InstructionCombining.cpp +++ lib/Transforms/InstCombine/InstructionCombining.cpp @@ -2240,11 +2240,11 @@ ConstantInt *AddRHS; if (match(Cond, m_Add(m_Value(Op0), m_ConstantInt(AddRHS)))) { // Change 'switch (X+4) case 1:' into 'switch (X) case -3'. - for (SwitchInst::CaseIt CaseIter : SI.cases()) { - Constant *NewCase = ConstantExpr::getSub(CaseIter.getCaseValue(), AddRHS); + for (auto Case : SI.cases()) { + Constant *NewCase = ConstantExpr::getSub(Case.getCaseValue(), AddRHS); assert(isa(NewCase) && "Result of expression should be constant"); - CaseIter.setValue(cast(NewCase)); + Case.setValue(cast(NewCase)); } SI.setCondition(Op0); return &SI; @@ -2276,9 +2276,9 @@ Value *NewCond = Builder->CreateTrunc(Cond, Ty, "trunc"); SI.setCondition(NewCond); - for (SwitchInst::CaseIt CaseIter : SI.cases()) { - APInt TruncatedCase = CaseIter.getCaseValue()->getValue().trunc(NewWidth); - CaseIter.setValue(ConstantInt::get(SI.getContext(), TruncatedCase)); + for (auto Case : SI.cases()) { + APInt TruncatedCase = Case.getCaseValue()->getValue().trunc(NewWidth); + Case.setValue(ConstantInt::get(SI.getContext(), TruncatedCase)); } return &SI; } @@ -3058,17 +3058,7 @@ } } else if (SwitchInst *SI = dyn_cast(TI)) { if (ConstantInt *Cond = dyn_cast(SI->getCondition())) { - // See if this is an explicit destination. - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); - i != e; ++i) - if (i.getCaseValue() == Cond) { - BasicBlock *ReachableBB = i.getCaseSuccessor(); - Worklist.push_back(ReachableBB); - continue; - } - - // Otherwise it is the default destination. - Worklist.push_back(SI->getDefaultDest()); + Worklist.push_back(SI->findCaseValue(Cond)->getCaseSuccessor()); continue; } } Index: lib/Transforms/Scalar/CorrelatedValuePropagation.cpp =================================================================== --- lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -236,7 +236,7 @@ // removing a case doesn't cause trouble for the iteration. bool Changed = false; for (auto CI = SI->case_begin(), CE = SI->case_end(); CI != CE;) { - ConstantInt *Case = CI.getCaseValue(); + ConstantInt *Case = CI->getCaseValue(); // Check to see if the switch condition is equal to/not equal to the case // value on every incoming edge, equal/not equal being the same each time. @@ -269,7 +269,7 @@ if (State == LazyValueInfo::False) { // This case never fires - remove it. - CI.getCaseSuccessor()->removePredecessor(BB); + CI->getCaseSuccessor()->removePredecessor(BB); CI = SI->removeCase(CI); CE = SI->case_end(); Index: lib/Transforms/Scalar/GVN.cpp =================================================================== --- lib/Transforms/Scalar/GVN.cpp +++ lib/Transforms/Scalar/GVN.cpp @@ -1761,11 +1761,11 @@ for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i) { - BasicBlock *Dst = i.getCaseSuccessor(); + BasicBlock *Dst = i->getCaseSuccessor(); // If there is only a single edge, propagate the case value into it. if (SwitchEdges.lookup(Dst) == 1) { BasicBlockEdge E(Parent, Dst); - Changed |= propagateEquality(SwitchCond, i.getCaseValue(), E, true); + Changed |= propagateEquality(SwitchCond, i->getCaseValue(), E, true); } } return Changed; Index: lib/Transforms/Scalar/JumpThreading.cpp =================================================================== --- lib/Transforms/Scalar/JumpThreading.cpp +++ lib/Transforms/Scalar/JumpThreading.cpp @@ -1269,7 +1269,7 @@ else if (BranchInst *BI = dyn_cast(BB->getTerminator())) DestBB = BI->getSuccessor(cast(Val)->isZero()); else if (SwitchInst *SI = dyn_cast(BB->getTerminator())) { - DestBB = SI->findCaseValue(cast(Val)).getCaseSuccessor(); + DestBB = SI->findCaseValue(cast(Val))->getCaseSuccessor(); } else { assert(isa(BB->getTerminator()) && "Unexpected terminator"); Index: lib/Transforms/Scalar/LoopUnrollPass.cpp =================================================================== --- lib/Transforms/Scalar/LoopUnrollPass.cpp +++ lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -509,7 +509,7 @@ KnownSucc = SI->getSuccessor(0); else if (ConstantInt *SimpleCondVal = dyn_cast(SimpleCond)) - KnownSucc = SI->findCaseValue(SimpleCondVal).getCaseSuccessor(); + KnownSucc = SI->findCaseValue(SimpleCondVal)->getCaseSuccessor(); } } if (KnownSucc) { Index: lib/Transforms/Scalar/LoopUnswitch.cpp =================================================================== --- lib/Transforms/Scalar/LoopUnswitch.cpp +++ lib/Transforms/Scalar/LoopUnswitch.cpp @@ -709,9 +709,8 @@ // Do not process same value again and again. // At this point we have some cases already unswitched and // some not yet unswitched. Let's find the first not yet unswitched one. - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); - i != e; ++i) { - Constant *UnswitchValCandidate = i.getCaseValue(); + for (auto Case : SI->cases()) { + Constant *UnswitchValCandidate = Case.getCaseValue(); if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) { UnswitchVal = UnswitchValCandidate; break; @@ -987,7 +986,7 @@ if (!Cond) break; // Find the target block we are definitely going to. - CurrentBB = SI->findCaseValue(Cond).getCaseSuccessor(); + CurrentBB = SI->findCaseValue(Cond)->getCaseSuccessor(); } else { // We do not understand these terminator instructions. break; @@ -1051,13 +1050,12 @@ // this. // Note that we can't trivially unswitch on the default case or // on already unswitched cases. - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); - i != e; ++i) { + for (auto Case : SI->cases()) { BasicBlock *LoopExitCandidate; - if ((LoopExitCandidate = isTrivialLoopExitBlock(currentLoop, - i.getCaseSuccessor()))) { + if ((LoopExitCandidate = + isTrivialLoopExitBlock(currentLoop, Case.getCaseSuccessor()))) { // Okay, we found a trivial case, remember the value that is trivial. - ConstantInt *CaseVal = i.getCaseValue(); + ConstantInt *CaseVal = Case.getCaseValue(); // Check that it was not unswitched before, since already unswitched // trivial vals are looks trivial too. @@ -1361,9 +1359,11 @@ // NOTE: if a case value for the switch is unswitched out, we record it // after the unswitch finishes. We can not record it here as the switch // is not a direct user of the partial LIV. - SwitchInst::CaseIt DeadCase = SI->findCaseValue(cast(Val)); + SwitchInst::CaseHandle DeadCase = + *SI->findCaseValue(cast(Val)); // Default case is live for multiple values. - if (DeadCase == SI->case_default()) continue; + if (DeadCase == *SI->case_default()) + continue; // Found a dead case value. Don't remove PHI nodes in the // successor if they become single-entry, those PHI nodes may Index: lib/Transforms/Scalar/LowerExpectIntrinsic.cpp =================================================================== --- lib/Transforms/Scalar/LowerExpectIntrinsic.cpp +++ lib/Transforms/Scalar/LowerExpectIntrinsic.cpp @@ -67,11 +67,11 @@ if (!ExpectedValue) return false; - SwitchInst::CaseIt Case = SI.findCaseValue(ExpectedValue); + SwitchInst::CaseHandle Case = *SI.findCaseValue(ExpectedValue); unsigned n = SI.getNumCases(); // +1 for default case. SmallVector Weights(n + 1, UnlikelyBranchWeight); - if (Case == SI.case_default()) + if (Case == *SI.case_default()) Weights[0] = LikelyBranchWeight; else Weights[Case.getCaseIndex() + 1] = LikelyBranchWeight; Index: lib/Transforms/Scalar/NewGVN.cpp =================================================================== --- lib/Transforms/Scalar/NewGVN.cpp +++ lib/Transforms/Scalar/NewGVN.cpp @@ -2002,8 +2002,8 @@ if (CondEvaluated && isa(CondEvaluated)) { auto *CondVal = cast(CondEvaluated); // We should be able to get case value for this. - auto CaseVal = SI->findCaseValue(CondVal); - if (CaseVal.getCaseSuccessor() == SI->getDefaultDest()) { + auto Case = *SI->findCaseValue(CondVal); + if (Case.getCaseSuccessor() == SI->getDefaultDest()) { // We proved the value is outside of the range of the case. // We can't do anything other than mark the default dest as reachable, // and go home. @@ -2011,7 +2011,7 @@ return; } // Now get where it goes and mark it reachable. - BasicBlock *TargetBlock = CaseVal.getCaseSuccessor(); + BasicBlock *TargetBlock = Case.getCaseSuccessor(); updateReachableEdge(B, TargetBlock); } else { for (unsigned i = 0, e = SI->getNumSuccessors(); i != e; ++i) { Index: lib/Transforms/Scalar/SCCP.cpp =================================================================== --- lib/Transforms/Scalar/SCCP.cpp +++ lib/Transforms/Scalar/SCCP.cpp @@ -596,7 +596,7 @@ return; } - Succs[SI->findCaseValue(CI).getSuccessorIndex()] = true; + Succs[SI->findCaseValue(CI)->getSuccessorIndex()] = true; return; } @@ -653,7 +653,7 @@ if (!CI) return !SCValue.isUnknown(); - return SI->findCaseValue(CI).getCaseSuccessor() == To; + return SI->findCaseValue(CI)->getCaseSuccessor() == To; } // Just mark all destinations executable! @@ -1491,12 +1491,12 @@ // If the input to SCCP is actually switch on undef, fix the undef to // the first constant. if (isa(SI->getCondition())) { - SI->setCondition(SI->case_begin().getCaseValue()); - markEdgeExecutable(&BB, SI->case_begin().getCaseSuccessor()); + SI->setCondition(SI->case_begin()->getCaseValue()); + markEdgeExecutable(&BB, SI->case_begin()->getCaseSuccessor()); return true; } - markForcedConstant(SI->getCondition(), SI->case_begin().getCaseValue()); + markForcedConstant(SI->getCondition(), SI->case_begin()->getCaseValue()); return true; } } Index: lib/Transforms/Utils/CloneFunction.cpp =================================================================== --- lib/Transforms/Utils/CloneFunction.cpp +++ lib/Transforms/Utils/CloneFunction.cpp @@ -352,7 +352,7 @@ Cond = dyn_cast_or_null(V); } if (Cond) { // Constant fold to uncond branch! - SwitchInst::ConstCaseIt Case = SI->findCaseValue(Cond); + SwitchInst::ConstCaseHandle Case = *SI->findCaseValue(Cond); BasicBlock *Dest = const_cast(Case.getCaseSuccessor()); VMap[OldTI] = BranchInst::Create(Dest, NewBB); ToClone.push_back(Dest); Index: lib/Transforms/Utils/Evaluator.cpp =================================================================== --- lib/Transforms/Utils/Evaluator.cpp +++ lib/Transforms/Utils/Evaluator.cpp @@ -487,7 +487,7 @@ ConstantInt *Val = dyn_cast(getVal(SI->getCondition())); if (!Val) return false; // Cannot determine. - NextBB = SI->findCaseValue(Val).getCaseSuccessor(); + NextBB = SI->findCaseValue(Val)->getCaseSuccessor(); } else if (IndirectBrInst *IBI = dyn_cast(CurInst)) { Value *Val = getVal(IBI->getAddress())->stripPointerCasts(); if (BlockAddress *BA = dyn_cast(Val)) Index: lib/Transforms/Utils/Local.cpp =================================================================== --- lib/Transforms/Utils/Local.cpp +++ lib/Transforms/Utils/Local.cpp @@ -126,20 +126,20 @@ // If the default is unreachable, ignore it when searching for TheOnlyDest. if (isa(DefaultDest->getFirstNonPHIOrDbg()) && SI->getNumCases() > 0) { - TheOnlyDest = SI->case_begin().getCaseSuccessor(); + TheOnlyDest = SI->case_begin()->getCaseSuccessor(); } // Figure out which case it goes to. for (auto i = SI->case_begin(), e = SI->case_end(); i != e;) { // Found case matching a constant operand? - if (i.getCaseValue() == CI) { - TheOnlyDest = i.getCaseSuccessor(); + if (i->getCaseValue() == CI) { + TheOnlyDest = i->getCaseSuccessor(); break; } // Check to see if this branch is going to the same place as the default // dest. If so, eliminate it as an explicit compare. - if (i.getCaseSuccessor() == DefaultDest) { + if (i->getCaseSuccessor() == DefaultDest) { MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); unsigned NCases = SI->getNumCases(); // Fold the case metadata into the default if there will be any branches @@ -153,7 +153,7 @@ Weights.push_back(CI->getValue().getZExtValue()); } // Merge weight of this case to the default weight. - unsigned idx = i.getCaseIndex(); + unsigned idx = i->getCaseIndex(); Weights[0] += Weights[idx+1]; // Remove weight for this case. std::swap(Weights[idx+1], Weights.back()); @@ -172,7 +172,8 @@ // Otherwise, check to see if the switch only branches to one destination. // We do this by reseting "TheOnlyDest" to null when we find two non-equal // destinations. - if (i.getCaseSuccessor() != TheOnlyDest) TheOnlyDest = nullptr; + if (i->getCaseSuccessor() != TheOnlyDest) + TheOnlyDest = nullptr; // Increment this iterator as we haven't removed the case. ++i; @@ -211,7 +212,7 @@ if (SI->getNumCases() == 1) { // Otherwise, we can fold this switch into a conditional branch // instruction if it has only one non-default destination. - SwitchInst::CaseIt FirstCase = SI->case_begin(); + auto FirstCase = *SI->case_begin(); Value *Cond = Builder.CreateICmpEQ(SI->getCondition(), FirstCase.getCaseValue(), "cond"); Index: lib/Transforms/Utils/LowerSwitch.cpp =================================================================== --- lib/Transforms/Utils/LowerSwitch.cpp +++ lib/Transforms/Utils/LowerSwitch.cpp @@ -356,10 +356,10 @@ unsigned numCmps = 0; // Start with "simple" cases - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i) - Cases.push_back(CaseRange(i.getCaseValue(), i.getCaseValue(), - i.getCaseSuccessor())); - + for (auto Case : SI->cases()) + Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(), + Case.getCaseSuccessor())); + std::sort(Cases.begin(), Cases.end(), CaseCmp()); // Merge case into clusters Index: lib/Transforms/Utils/SimplifyCFG.cpp =================================================================== --- lib/Transforms/Utils/SimplifyCFG.cpp +++ lib/Transforms/Utils/SimplifyCFG.cpp @@ -714,10 +714,9 @@ TerminatorInst *TI, std::vector &Cases) { if (SwitchInst *SI = dyn_cast(TI)) { Cases.reserve(SI->getNumCases()); - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; - ++i) - Cases.push_back( - ValueEqualityComparisonCase(i.getCaseValue(), i.getCaseSuccessor())); + for (auto Case : SI->cases()) + Cases.push_back(ValueEqualityComparisonCase(Case.getCaseValue(), + Case.getCaseSuccessor())); return SI->getDefaultDest(); } @@ -850,12 +849,12 @@ } for (SwitchInst::CaseIt i = SI->case_end(), e = SI->case_begin(); i != e;) { --i; - if (DeadCases.count(i.getCaseValue())) { + if (DeadCases.count(i->getCaseValue())) { if (HasWeight) { - std::swap(Weights[i.getCaseIndex() + 1], Weights.back()); + std::swap(Weights[i->getCaseIndex() + 1], Weights.back()); Weights.pop_back(); } - i.getCaseSuccessor()->removePredecessor(TI->getParent()); + i->getCaseSuccessor()->removePredecessor(TI->getParent()); SI->removeCase(i); } } @@ -3444,8 +3443,8 @@ // Find the relevant condition and destinations. Value *Condition = Select->getCondition(); - BasicBlock *TrueBB = SI->findCaseValue(TrueVal).getCaseSuccessor(); - BasicBlock *FalseBB = SI->findCaseValue(FalseVal).getCaseSuccessor(); + BasicBlock *TrueBB = SI->findCaseValue(TrueVal)->getCaseSuccessor(); + BasicBlock *FalseBB = SI->findCaseValue(FalseVal)->getCaseSuccessor(); // Get weight for TrueBB and FalseBB. uint32_t TrueWeight = 0, FalseWeight = 0; @@ -3455,9 +3454,9 @@ GetBranchWeights(SI, Weights); if (Weights.size() == 1 + SI->getNumCases()) { TrueWeight = - (uint32_t)Weights[SI->findCaseValue(TrueVal).getSuccessorIndex()]; + (uint32_t)Weights[SI->findCaseValue(TrueVal)->getSuccessorIndex()]; FalseWeight = - (uint32_t)Weights[SI->findCaseValue(FalseVal).getSuccessorIndex()]; + (uint32_t)Weights[SI->findCaseValue(FalseVal)->getSuccessorIndex()]; } } @@ -4160,7 +4159,7 @@ } } else if (auto *SI = dyn_cast(TI)) { for (auto i = SI->case_begin(), e = SI->case_end(); i != e;) { - if (i.getCaseSuccessor() != BB) { + if (i->getCaseSuccessor() != BB) { ++i; continue; } @@ -4251,18 +4250,18 @@ SmallVector CasesA; SmallVector CasesB; - for (SwitchInst::CaseIt I : SI->cases()) { - BasicBlock *Dest = I.getCaseSuccessor(); + for (auto Case : SI->cases()) { + BasicBlock *Dest = Case.getCaseSuccessor(); if (!DestA) DestA = Dest; if (Dest == DestA) { - CasesA.push_back(I.getCaseValue()); + CasesA.push_back(Case.getCaseValue()); continue; } if (!DestB) DestB = Dest; if (Dest == DestB) { - CasesB.push_back(I.getCaseValue()); + CasesB.push_back(Case.getCaseValue()); continue; } return false; // More than two destinations. @@ -4412,17 +4411,17 @@ // Remove dead cases from the switch. for (ConstantInt *DeadCase : DeadCases) { - SwitchInst::CaseIt Case = SI->findCaseValue(DeadCase); - assert(Case != SI->case_default() && + SwitchInst::CaseIt CaseI = SI->findCaseValue(DeadCase); + assert(CaseI != SI->case_default() && "Case was not found. Probably mistake in DeadCases forming."); if (HasWeight) { - std::swap(Weights[Case.getCaseIndex() + 1], Weights.back()); + std::swap(Weights[CaseI->getCaseIndex() + 1], Weights.back()); Weights.pop_back(); } // Prune unused values from PHI nodes. - Case.getCaseSuccessor()->removePredecessor(SI->getParent()); - SI->removeCase(Case); + CaseI->getCaseSuccessor()->removePredecessor(SI->getParent()); + SI->removeCase(CaseI); } if (HasWeight && Weights.size() >= 2) { SmallVector MDWeights(Weights.begin(), Weights.end()); @@ -4476,10 +4475,9 @@ typedef DenseMap> ForwardingNodesMap; ForwardingNodesMap ForwardingNodes; - for (SwitchInst::CaseIt I = SI->case_begin(), E = SI->case_end(); I != E; - ++I) { - ConstantInt *CaseValue = I.getCaseValue(); - BasicBlock *CaseDest = I.getCaseSuccessor(); + for (auto Case : SI->cases()) { + ConstantInt *CaseValue = Case.getCaseValue(); + BasicBlock *CaseDest = Case.getCaseSuccessor(); int PhiIndex; PHINode *PHI = @@ -5214,8 +5212,8 @@ // common destination, as well as the min and max case values. assert(SI->case_begin() != SI->case_end()); SwitchInst::CaseIt CI = SI->case_begin(); - ConstantInt *MinCaseVal = CI.getCaseValue(); - ConstantInt *MaxCaseVal = CI.getCaseValue(); + ConstantInt *MinCaseVal = CI->getCaseValue(); + ConstantInt *MaxCaseVal = CI->getCaseValue(); BasicBlock *CommonDest = nullptr; typedef SmallVector, 4> ResultListTy; @@ -5225,7 +5223,7 @@ SmallVector PHIs; for (SwitchInst::CaseIt E = SI->case_end(); CI != E; ++CI) { - ConstantInt *CaseVal = CI.getCaseValue(); + ConstantInt *CaseVal = CI->getCaseValue(); if (CaseVal->getValue().slt(MinCaseVal->getValue())) MinCaseVal = CaseVal; if (CaseVal->getValue().sgt(MaxCaseVal->getValue())) @@ -5234,7 +5232,7 @@ // Resulting value at phi nodes for this case value. typedef SmallVector, 4> ResultsTy; ResultsTy Results; - if (!GetCaseResults(SI, CaseVal, CI.getCaseSuccessor(), &CommonDest, + if (!GetCaseResults(SI, CaseVal, CI->getCaseSuccessor(), &CommonDest, Results, DL, TTI)) return false; @@ -5515,11 +5513,10 @@ auto *Rot = Builder.CreateOr(LShr, Shl); SI->replaceUsesOfWith(SI->getCondition(), Rot); - for (SwitchInst::CaseIt C = SI->case_begin(), E = SI->case_end(); C != E; - ++C) { - auto *Orig = C.getCaseValue(); + for (auto Case : SI->cases()) { + auto *Orig = Case.getCaseValue(); auto Sub = Orig->getValue() - APInt(Ty->getBitWidth(), Base); - C.setValue( + Case.setValue( cast(ConstantInt::get(Ty, Sub.lshr(ShiftC->getValue())))); } return true; Index: tools/llvm-diff/DifferenceEngine.cpp =================================================================== --- tools/llvm-diff/DifferenceEngine.cpp +++ tools/llvm-diff/DifferenceEngine.cpp @@ -315,17 +315,15 @@ bool Difference = false; DenseMap LCases; - - for (SwitchInst::CaseIt I = LI->case_begin(), E = LI->case_end(); - I != E; ++I) - LCases[I.getCaseValue()] = I.getCaseSuccessor(); - - for (SwitchInst::CaseIt I = RI->case_begin(), E = RI->case_end(); - I != E; ++I) { - ConstantInt *CaseValue = I.getCaseValue(); + for (auto Case : LI->cases()) + LCases[Case.getCaseValue()] = Case.getCaseSuccessor(); + + for (auto Case : RI->cases()) { + ConstantInt *CaseValue = Case.getCaseValue(); BasicBlock *LCase = LCases[CaseValue]; if (LCase) { - if (TryUnify) tryUnify(LCase, I.getCaseSuccessor()); + if (TryUnify) + tryUnify(LCase, Case.getCaseSuccessor()); LCases.erase(CaseValue); } else if (Complain || !Difference) { if (Complain) Index: unittests/IR/InstructionsTest.cpp =================================================================== --- unittests/IR/InstructionsTest.cpp +++ unittests/IR/InstructionsTest.cpp @@ -677,5 +677,68 @@ delete GEPI; } +TEST(InstructionsTest, SwitchInst) { + LLVMContext C; + + std::unique_ptr BB1, BB2, BB3; + BB1.reset(BasicBlock::Create(C)); + BB2.reset(BasicBlock::Create(C)); + BB3.reset(BasicBlock::Create(C)); + + // We create block 0 after the others so that it gets destroyed first and + // clears the uses of the other basic blocks. + std::unique_ptr BB0(BasicBlock::Create(C)); + + auto *Int32Ty = Type::getInt32Ty(C); + + SwitchInst *SI = + SwitchInst::Create(UndefValue::get(Int32Ty), BB0.get(), 3, BB0.get()); + SI->addCase(ConstantInt::get(Int32Ty, 1), BB1.get()); + SI->addCase(ConstantInt::get(Int32Ty, 2), BB2.get()); + SI->addCase(ConstantInt::get(Int32Ty, 3), BB3.get()); + + auto CI = SI->case_begin(); + ASSERT_NE(CI, SI->case_end()); + EXPECT_EQ(1, CI->getCaseValue()->getSExtValue()); + EXPECT_EQ(BB1.get(), CI->getCaseSuccessor()); + EXPECT_EQ(2, (CI + 1)->getCaseValue()->getSExtValue()); + EXPECT_EQ(BB2.get(), (CI + 1)->getCaseSuccessor()); + EXPECT_EQ(3, (CI + 2)->getCaseValue()->getSExtValue()); + EXPECT_EQ(BB3.get(), (CI + 2)->getCaseSuccessor()); + EXPECT_EQ(CI + 1, std::next(CI)); + EXPECT_EQ(CI + 2, std::next(CI, 2)); + EXPECT_EQ(CI + 3, std::next(CI, 3)); + EXPECT_EQ(SI->case_end(), CI + 3); + EXPECT_EQ(0, CI - CI); + EXPECT_EQ(1, (CI + 1) - CI); + EXPECT_EQ(2, (CI + 2) - CI); + EXPECT_EQ(3, SI->case_end() - CI); + EXPECT_EQ(3, std::distance(CI, SI->case_end())); + + auto CCI = const_cast(SI)->case_begin(); + SwitchInst::ConstCaseIt CCE = SI->case_end(); + ASSERT_NE(CCI, SI->case_end()); + EXPECT_EQ(1, CCI->getCaseValue()->getSExtValue()); + EXPECT_EQ(BB1.get(), CCI->getCaseSuccessor()); + EXPECT_EQ(2, (CCI + 1)->getCaseValue()->getSExtValue()); + EXPECT_EQ(BB2.get(), (CCI + 1)->getCaseSuccessor()); + EXPECT_EQ(3, (CCI + 2)->getCaseValue()->getSExtValue()); + EXPECT_EQ(BB3.get(), (CCI + 2)->getCaseSuccessor()); + EXPECT_EQ(CCI + 1, std::next(CCI)); + EXPECT_EQ(CCI + 2, std::next(CCI, 2)); + EXPECT_EQ(CCI + 3, std::next(CCI, 3)); + EXPECT_EQ(CCE, CCI + 3); + EXPECT_EQ(0, CCI - CCI); + EXPECT_EQ(1, (CCI + 1) - CCI); + EXPECT_EQ(2, (CCI + 2) - CCI); + EXPECT_EQ(3, CCE - CCI); + EXPECT_EQ(3, std::distance(CCI, CCE)); + + // Make sure that the const iterator is compatible with a const auto ref. + const auto &Handle = *CCI; + EXPECT_EQ(1, Handle.getCaseValue()->getSExtValue()); + EXPECT_EQ(BB1.get(), Handle.getCaseSuccessor()); +} + } // end anonymous namespace } // end namespace llvm