diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -538,7 +538,7 @@ early_inc_iterator_impl(WrappedIteratorT I) : BaseT(I) {} using BaseT::operator*; - typename BaseT::reference operator*() { + decltype(*std::declval<WrappedIteratorT>()) operator*() { #if LLVM_ENABLE_ABI_BREAKING_CHECKS assert(!IsEarlyIncremented && "Cannot dereference twice!"); IsEarlyIncremented = true; diff --git a/llvm/lib/Analysis/MemoryBuiltins.cpp b/llvm/lib/Analysis/MemoryBuiltins.cpp --- a/llvm/lib/Analysis/MemoryBuiltins.cpp +++ b/llvm/lib/Analysis/MemoryBuiltins.cpp @@ -378,9 +378,8 @@ unsigned NumOfBitCastUses = 0; // Determine if CallInst has a bitcast use. - for (Value::const_user_iterator UI = CI->user_begin(), E = CI->user_end(); - UI != E;) - if (const BitCastInst *BCI = dyn_cast<BitCastInst>(*UI++)) { + for (const User *U : CI->users()) + if (const BitCastInst *BCI = dyn_cast<BitCastInst>(U)) { MallocType = cast<PointerType>(BCI->getDestTy()); NumOfBitCastUses++; } diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp --- a/llvm/lib/IR/AutoUpgrade.cpp +++ b/llvm/lib/IR/AutoUpgrade.cpp @@ -3894,8 +3894,8 @@ if (UpgradeIntrinsicFunction(F, NewFn)) { // Replace all users of the old function with the new function or new - // instructions. This is not a range loop because the call is deleted. - for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE; ) - if (CallInst *CI = dyn_cast<CallInst>(*UI++)) + // instructions. Use an early-increment range because the call is deleted. + for (User *U : make_early_inc_range(F->users())) + if (CallInst *CI = dyn_cast<CallInst>(U)) UpgradeIntrinsicCall(CI, NewFn); // Remove old function, no longer used, from the module. 
@@ -4031,8 +4031,8 @@ Function *NewFn = llvm::Intrinsic::getDeclaration(&M, IntrinsicFunc); - for (auto I = Fn->user_begin(), E = Fn->user_end(); I != E;) { - CallInst *CI = dyn_cast<CallInst>(*I++); + for (User *U : make_early_inc_range(Fn->users())) { + CallInst *CI = dyn_cast<CallInst>(U); if (!CI || CI->getCalledFunction() != Fn) continue; diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp --- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -160,8 +160,8 @@ // In this table, we will track which indices are loaded from the argument // (where direct loads are tracked as no indices). ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; - for (auto Iter = I->user_begin(), End = I->user_end(); Iter != End;) { - Instruction *UI = cast<Instruction>(*Iter++); + for (User *U : make_early_inc_range(I->users())) { + Instruction *UI = cast<Instruction>(U); Type *SrcTy; if (LoadInst *L = dyn_cast<LoadInst>(UI)) SrcTy = L->getType(); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -2500,10 +2500,7 @@ Instruction *RetVal = nullptr; for (auto *OldPN : OldPhiNodes) { PHINode *NewPN = NewPNodes[OldPN]; - for (auto It = OldPN->user_begin(), End = OldPN->user_end(); It != End; ) { - User *V = *It; - // We may remove this user, advance to avoid iterator invalidation. 
- ++It; + for (User *V : make_early_inc_range(OldPN->users())) { if (auto *SI = dyn_cast<StoreInst>(V)) { assert(SI->isSimple() && SI->getOperand(0) == OldPN); Builder.SetInsertPoint(SI); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -4751,8 +4751,7 @@ // mul.with.overflow and adjust properly mask/size. if (MulVal->hasNUsesOrMore(2)) { Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value"); - for (auto UI = MulVal->user_begin(), UE = MulVal->user_end(); UI != UE;) { - User *U = *UI++; + for (User *U : make_early_inc_range(MulVal->users())) { if (U == &I || U == OtherVal) continue; if (TruncInst *TI = dyn_cast<TruncInst>(U)) { diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -1175,8 +1175,8 @@ } } - for (auto UI = PN->user_begin(), E = PN->user_end(); UI != E;) { - Instruction *User = cast<Instruction>(*UI++); + for (User *U : make_early_inc_range(PN->users())) { + Instruction *User = cast<Instruction>(U); if (User == &I) continue; replaceInstUsesWith(*User, NewPN); eraseInstFromFunction(*User); diff --git a/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp --- a/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -129,8 +129,8 @@ // As we scan the uses of the alloca instruction, keep track of stores, // and decide whether all of the loads and stores to the alloca are within // the same basic block. 
- for (auto UI = AI->user_begin(), E = AI->user_end(); UI != E;) { - Instruction *User = cast<Instruction>(*UI++); + for (User *U : AI->users()) { + Instruction *User = cast<Instruction>(U); if (StoreInst *SI = dyn_cast<StoreInst>(User)) { // Remember the basic blocks which define new values for the alloca @@ -366,8 +366,8 @@ // Clear out UsingBlocks. We will reconstruct it here if needed. Info.UsingBlocks.clear(); - for (auto UI = AI->user_begin(), E = AI->user_end(); UI != E;) { - Instruction *UserInst = cast<Instruction>(*UI++); + for (User *U : make_early_inc_range(AI->users())) { + Instruction *UserInst = cast<Instruction>(U); if (UserInst == OnlyStore) continue; LoadInst *LI = cast<LoadInst>(UserInst); @@ -480,8 +480,8 @@ // Walk all of the loads from this alloca, replacing them with the nearest // store above them, if any. - for (auto UI = AI->user_begin(), E = AI->user_end(); UI != E;) { - LoadInst *LI = dyn_cast<LoadInst>(*UI++); + for (User *U : make_early_inc_range(AI->users())) { + LoadInst *LI = dyn_cast<LoadInst>(U); if (!LI) continue; diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp --- a/llvm/unittests/ADT/STLExtrasTest.cpp +++ b/llvm/unittests/ADT/STLExtrasTest.cpp @@ -451,6 +451,79 @@ EXPECT_EQ(EIR.end(), I); } +// A custom iterator that returns a pointer when dereferenced. This is used to +// test make_early_inc_range with iterators that do not return a reference on +// dereferencing. +struct CustomPointerIterator + : public iterator_adaptor_base<CustomPointerIterator, + std::list<int>::iterator, + std::forward_iterator_tag> { + using base_type = + iterator_adaptor_base<CustomPointerIterator, std::list<int>::iterator, + std::forward_iterator_tag>; + + explicit CustomPointerIterator(std::list<int>::iterator I) : base_type(I) {} + + // Retrieve a pointer to the current int. + int *operator*() const { return &*base_type::wrapped(); } +}; + +// Make sure make_early_inc_range works with iterators that do not return a +// reference on dereferencing. The test is similar to EarlyIncrementTest, but +// uses CustomPointerIterator. 
+TEST(STLExtrasTest, EarlyIncrementTestCustomPointerIterator) { + std::list<int> L = {1, 2, 3, 4}; + + auto CustomRange = make_range(CustomPointerIterator(L.begin()), + CustomPointerIterator(L.end())); + auto EIR = make_early_inc_range(CustomRange); + + auto I = EIR.begin(); + auto EI = EIR.end(); + EXPECT_NE(I, EI); + + EXPECT_EQ(&*L.begin(), *I); +#if LLVM_ENABLE_ABI_BREAKING_CHECKS +#ifndef NDEBUG + // Repeated dereferences are not allowed. + EXPECT_DEATH(*I, "Cannot dereference"); + // Comparison after dereference is not allowed. + EXPECT_DEATH((void)(I == EI), "Cannot compare"); + EXPECT_DEATH((void)(I != EI), "Cannot compare"); +#endif +#endif + + ++I; + EXPECT_NE(I, EI); +#if LLVM_ENABLE_ABI_BREAKING_CHECKS +#ifndef NDEBUG + // You cannot increment prior to dereference. + EXPECT_DEATH(++I, "Cannot increment"); +#endif +#endif + EXPECT_EQ(&*std::next(L.begin()), *I); +#if LLVM_ENABLE_ABI_BREAKING_CHECKS +#ifndef NDEBUG + // Repeated dereferences are not allowed. + EXPECT_DEATH(*I, "Cannot dereference"); +#endif +#endif + + // Inserting shouldn't break anything. We should be able to keep dereferencing + // the current iterator and increment. The increment should go to the "next" + // iterator from before we inserted. + L.insert(std::next(L.begin(), 2), -1); + ++I; + EXPECT_EQ(&*std::next(L.begin(), 3), *I); + + // Erasing the front including the current doesn't break incrementing. + L.erase(L.begin(), std::prev(L.end())); + ++I; + EXPECT_EQ(&*L.begin(), *I); + ++I; + EXPECT_EQ(EIR.end(), I); +} + TEST(STLExtrasTest, splat) { std::vector<int> V; EXPECT_FALSE(is_splat(V));