diff --git a/llvm/include/llvm/Analysis/AssumptionCache.h b/llvm/include/llvm/Analysis/AssumptionCache.h --- a/llvm/include/llvm/Analysis/AssumptionCache.h +++ b/llvm/include/llvm/Analysis/AssumptionCache.h @@ -18,7 +18,9 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" @@ -44,6 +46,22 @@ /// llvm.assume. enum : unsigned { ExprResultIdx = std::numeric_limits::max() }; + /// Callback handle to ensure we do not have dangling pointers to llvm.assume + /// calls in our cache. + class AssumeHandle final : public CallbackVH { + AssumptionCache *AC; + + /// Make sure llvm.assume calls that are deleted are removed from the cache. + void deleted() override; + + public: + AssumeHandle(Value *V, AssumptionCache *AC = nullptr) + : CallbackVH(V), AC(AC) {} + + operator Value *() const { return getValPtr(); } + CallInst *getAssumeCI() const { return cast(getValPtr()); } + }; + struct ResultElem { WeakVH Assume; @@ -59,9 +77,9 @@ /// We track this to lazily populate our assumptions. Function &F; - /// Vector of weak value handles to calls of the \@llvm.assume - /// intrinsic. - SmallVector AssumeHandles; + /// Set of value handles for calls of the \@llvm.assume intrinsic. + using AssumeHandleSet = DenseSet>; + AssumeHandleSet AssumeHandles; class AffectedValueCallbackVH final : public CallbackVH { AssumptionCache *AC; @@ -137,13 +155,7 @@ /// Access the list of assumption handles currently tracked for this /// function. - /// - /// Note that these produce weak handles that may be null. The caller must - /// handle that case. - /// FIXME: We should replace this with pointee_iterator> - /// when we can write that to filter out the null values. Then caller code - /// will become simpler. - MutableArrayRef assumptions() { + AssumeHandleSet &assumptions() { if (!Scanned) scanFunction(); return AssumeHandles; diff --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp --- a/llvm/lib/Analysis/AssumptionCache.cpp +++ b/llvm/lib/Analysis/AssumptionCache.cpp @@ -163,7 +163,12 @@ AffectedValues.erase(AVI); } - erase_value(AssumeHandles, CI); + AssumeHandles.erase({CI, this}); +} + +void AssumptionCache::AssumeHandle::deleted() { + AC->AssumeHandles.erase(*this); + // 'this' now dangles! } void AssumptionCache::AffectedValueCallbackVH::deleted() { @@ -204,14 +209,14 @@ for (BasicBlock &B : F) for (Instruction &II : B) if (match(&II, m_Intrinsic())) - AssumeHandles.push_back({&II, ExprResultIdx}); + AssumeHandles.insert({&II, this}); // Mark the scan as complete. Scanned = true; // Update affected values. - for (auto &A : AssumeHandles) - updateAffectedValues(cast(A)); + for (auto &AssumeVH : AssumeHandles) + updateAffectedValues(AssumeVH.getAssumeCI()); } void AssumptionCache::registerAssumption(CallInst *CI) { @@ -223,7 +228,7 @@ if (!Scanned) return; - AssumeHandles.push_back({CI, ExprResultIdx}); + AssumeHandles.insert({CI, this}); #ifndef NDEBUG assert(CI->getParent() && @@ -231,20 +236,11 @@ assert(&F == CI->getParent()->getParent() && "Cannot register @llvm.assume call not in this function"); - // We expect the number of assumptions to be small, so in an asserts build - // check that we don't accumulate duplicates and that all assumptions point - // to the same function. - SmallPtrSet AssumptionSet; - for (auto &VH : AssumeHandles) { - if (!VH) - continue; - - assert(&F == cast(VH)->getParent()->getParent() && + for (auto &AssumeVH : AssumeHandles) { + assert(&F == AssumeVH.getAssumeCI()->getCaller() && "Cached assumption not inside this function!"); - assert(match(cast(VH), m_Intrinsic()) && + assert(match(AssumeVH.getAssumeCI(), m_Intrinsic()) && "Cached something other than a call to @llvm.assume!"); - assert(AssumptionSet.insert(VH).second && - "Cache contains multiple copies of a call!"); } #endif @@ -258,9 +254,8 @@ AssumptionCache &AC = AM.getResult(F); OS << "Cached assumptions for function: " << F.getName() << "\n"; - for (auto &VH : AC.assumptions()) - if (VH) - OS << " " << *cast(VH)->getArgOperand(0) << "\n"; + for (auto &AssumeVH : AC.assumptions()) + OS << " " << *AssumeVH.getAssumeCI()->getArgOperand(0) << "\n"; return PreservedAnalyses::all(); } @@ -306,9 +301,8 @@ SmallPtrSet AssumptionSet; for (const auto &I : AssumptionCaches) { - for (auto &VH : I.second->assumptions()) - if (VH) - AssumptionSet.insert(cast(VH)); + for (auto &AssumeVH : I.second->assumptions()) + AssumptionSet.insert(AssumeVH.getAssumeCI()); for (const BasicBlock &B : cast(*I.first)) for (const Instruction &II : B) diff --git a/llvm/lib/Analysis/CodeMetrics.cpp b/llvm/lib/Analysis/CodeMetrics.cpp --- a/llvm/lib/Analysis/CodeMetrics.cpp +++ b/llvm/lib/Analysis/CodeMetrics.cpp @@ -73,9 +73,7 @@ SmallVector Worklist; for (auto &AssumeVH : AC->assumptions()) { - if (!AssumeVH) - continue; - Instruction *I = cast(AssumeVH); + Instruction *I = AssumeVH.getAssumeCI(); // Filter out call sites outside of the loop so we don't do a function's // worth of work for each of its loops (and, in the common case, ephemeral @@ -97,9 +95,7 @@ SmallVector Worklist; for (auto &AssumeVH : AC->assumptions()) { - if (!AssumeVH) - continue; - Instruction *I = cast(AssumeVH); + Instruction *I = AssumeVH.getAssumeCI(); assert(I->getParent()->getParent() == F && "Found assumption for the wrong function!"); diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -1704,9 +1704,9 @@ getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); } - + // For a negative step, we can extend the operands iff doing so only - // traverses values in the range zext([0,UINT_MAX]). + // traverses values in the range zext([0,UINT_MAX]). if (isKnownNegative(Step)) { const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - getSignedRangeMin(Step)); @@ -9927,9 +9927,7 @@ // Check conditions due to any @llvm.assume intrinsics. for (auto &AssumeVH : AC.assumptions()) { - if (!AssumeVH) - continue; - auto *CI = cast(AssumeVH); + auto *CI = AssumeVH.getAssumeCI(); if (!DT.dominates(CI, Latch->getTerminator())) continue; @@ -10076,9 +10074,7 @@ // Check conditions due to any @llvm.assume intrinsics. for (auto &AssumeVH : AC.assumptions()) { - if (!AssumeVH) - continue; - auto *CI = cast(AssumeVH); + auto *CI = AssumeVH.getAssumeCI(); if (!DT.dominates(CI, BB)) continue; @@ -13358,9 +13354,7 @@ // Also collect information from assumptions dominating the loop. for (auto &AssumeVH : AC.assumptions()) { - if (!AssumeVH) - continue; - auto *AssumeI = cast(AssumeVH); + auto *AssumeI = AssumeVH.getAssumeCI(); auto *Cmp = dyn_cast(AssumeI->getOperand(0)); if (!Cmp || !DT.dominates(AssumeI, L->getHeader())) continue; diff --git a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp --- a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp +++ b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp @@ -331,12 +331,11 @@ DT = DT_; bool Changed = false; - for (auto &AssumeVH : AC.assumptions()) - if (AssumeVH) { - CallInst *Call = cast(AssumeVH); - for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++) - Changed |= processAssumption(Call, Idx); - } + for (auto &AssumeVH : AC.assumptions()) { + CallInst *Call = AssumeVH.getAssumeCI(); + for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++) + Changed |= processAssumption(Call, Idx); + } return Changed; } diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -1781,10 +1781,8 @@ bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc, const Function &NewFunc, AssumptionCache *AC) { - for (auto AssumeVH : AC->assumptions()) { - auto *I = dyn_cast_or_null(AssumeVH); - if (!I) - continue; + for (auto &AssumeVH : AC->assumptions()) { + auto *I = AssumeVH.getAssumeCI(); // There shouldn't be any llvm.assume intrinsics in the new function. if (I->getFunction() != &OldFunc) diff --git a/llvm/lib/Transforms/Utils/PredicateInfo.cpp b/llvm/lib/Transforms/Utils/PredicateInfo.cpp --- a/llvm/lib/Transforms/Utils/PredicateInfo.cpp +++ b/llvm/lib/Transforms/Utils/PredicateInfo.cpp @@ -532,10 +532,11 @@ processSwitch(SI, BranchBB, OpsToRename); } } - for (auto &Assume : AC.assumptions()) { - if (auto *II = dyn_cast_or_null(Assume)) - if (DT.isReachableFromEntry(II->getParent())) - processAssume(II, II->getParent(), OpsToRename); + for (auto &AssumeVH : AC.assumptions()) { + CallInst *AssumeCI = AssumeVH.getAssumeCI(); + if (DT.isReachableFromEntry(AssumeCI->getParent())) + processAssume(cast(AssumeCI), AssumeCI->getParent(), + OpsToRename); } // Now rename all our operations. renameUses(OpsToRename); diff --git a/llvm/test/Analysis/AssumptionCache/basic.ll b/llvm/test/Analysis/AssumptionCache/basic.ll --- a/llvm/test/Analysis/AssumptionCache/basic.ll +++ b/llvm/test/Analysis/AssumptionCache/basic.ll @@ -6,9 +6,9 @@ define void @test1(i32 %a) { ; CHECK-LABEL: Cached assumptions for function: test1 -; CHECK-NEXT: icmp ne i32 %{{.*}}, 0 -; CHECK-NEXT: icmp slt i32 %{{.*}}, 0 -; CHECK-NEXT: icmp sgt i32 %{{.*}}, 0 +; CHECK-DAG: icmp ne i32 %{{.*}}, 0 +; CHECK-DAG: icmp slt i32 %{{.*}}, 0 +; CHECK-DAG: icmp sgt i32 %{{.*}}, 0 entry: %cond1 = icmp ne i32 %a, 0