diff --git a/llvm/include/llvm/Analysis/AssumptionCache.h b/llvm/include/llvm/Analysis/AssumptionCache.h
--- a/llvm/include/llvm/Analysis/AssumptionCache.h
+++ b/llvm/include/llvm/Analysis/AssumptionCache.h
@@ -39,6 +39,21 @@
 /// register any new \@llvm.assume calls that they create. Deletions of
 /// \@llvm.assume calls do not require special handling.
 class AssumptionCache {
+public:
+  /// Value of ResultElem::Index indicating that the knowledge comes from the
+  /// argument of the llvm.assume call rather than from an operand bundle.
+  enum : unsigned { ExprResultIdx = std::numeric_limits<unsigned>::max() };
+
+  struct ResultElem {
+    WeakTrackingVH Assume;
+
+    /// contains either ExprResultIdx or the index of the operand bundle
+    /// containing the knowledge.
+    unsigned Index;
+    operator Value *() const { return Assume; }
+  };
+
+private:
   /// The function for which this cache is handling assumptions.
   ///
   /// We track this to lazily populate our assumptions.
@@ -46,7 +61,7 @@
 
   /// Vector of weak value handles to calls of the \@llvm.assume
   /// intrinsic.
-  SmallVector<WeakTrackingVH, 4> AssumeHandles;
+  SmallVector<ResultElem, 4> AssumeHandles;
 
   class AffectedValueCallbackVH final : public CallbackVH {
     AssumptionCache *AC;
@@ -66,12 +81,12 @@
   /// A map of values about which an assumption might be providing
   /// information to the relevant set of assumptions.
   using AffectedValuesMap =
-      DenseMap<AffectedValueCallbackVH, SmallVector<WeakTrackingVH, 1>,
+      DenseMap<AffectedValueCallbackVH, SmallVector<ResultElem, 1>,
                AffectedValueCallbackVH::DMI>;
   AffectedValuesMap AffectedValues;
 
   /// Get the vector of assumptions which affect a value from the cache.
-  SmallVector<WeakTrackingVH, 1> &getOrInsertAffectedValues(Value *V);
+  SmallVector<ResultElem, 1> &getOrInsertAffectedValues(Value *V);
 
   /// Move affected values in the cache for OV to be affected values for NV.
   void transferAffectedValuesInCache(Value *OV, Value *NV);
@@ -128,20 +143,20 @@
   /// FIXME: We should replace this with pointee_iterator<filter_iterator<...>>
   /// when we can write that to filter out the null values. Then caller code
   /// will become simpler.
-  MutableArrayRef<WeakTrackingVH> assumptions() {
+  MutableArrayRef<ResultElem> assumptions() {
     if (!Scanned)
       scanFunction();
     return AssumeHandles;
   }
 
   /// Access the list of assumptions which affect this value.
-  MutableArrayRef<WeakTrackingVH> assumptionsFor(const Value *V) {
+  MutableArrayRef<ResultElem> assumptionsFor(const Value *V) {
     if (!Scanned)
       scanFunction();
 
     auto AVI = AffectedValues.find_as(const_cast<Value *>(V));
     if (AVI == AffectedValues.end())
-      return MutableArrayRef<WeakTrackingVH>();
+      return MutableArrayRef<ResultElem>();
 
     return AVI->second;
   }
@@ -234,6 +249,21 @@
   static char ID; // Pass identification, replacement for typeid
 };
 
+template <> struct simplify_type<AssumptionCache::ResultElem> {
+  using SimpleType = Value *;
+
+  static SimpleType getSimplifiedValue(AssumptionCache::ResultElem &Val) {
+    return Val;
+  }
+};
+template <> struct simplify_type<const AssumptionCache::ResultElem> {
+  using SimpleType = /*const*/ Value *;
+
+  static SimpleType getSimplifiedValue(const AssumptionCache::ResultElem &Val) {
+    return Val;
+  }
+};
+
 } // end namespace llvm
 
 #endif // LLVM_ANALYSIS_ASSUMPTIONCACHE_H
diff --git a/llvm/include/llvm/Transforms/Utils/AssumeBundleBuilder.h b/llvm/include/llvm/Transforms/Utils/AssumeBundleBuilder.h
--- a/llvm/include/llvm/Transforms/Utils/AssumeBundleBuilder.h
+++ b/llvm/include/llvm/Transforms/Utils/AssumeBundleBuilder.h
@@ -22,6 +22,7 @@
 
 namespace llvm {
 class IntrinsicInst;
+class AssumptionCache;
 
 /// Build a call to llvm.assume to preserve informations that can be derived
 /// from the given instruction.
@@ -32,7 +33,8 @@
 /// Calls BuildAssumeFromInst and if the resulting llvm.assume is valid insert
 /// if before I. This is usually what need to be done to salvage the knowledge
 /// contained in the instruction I.
-void salvageKnowledge(Instruction *I);
+/// If \p AC is not null, the inserted llvm.assume is registered in it.
+void salvageKnowledge(Instruction *I, AssumptionCache *AC = nullptr);
 
 /// This pass will try to build an llvm.assume for every instruction in the
 /// function. Its main purpose is testing.
diff --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp
--- a/llvm/lib/Analysis/AssumptionCache.cpp
+++ b/llvm/lib/Analysis/AssumptionCache.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/AssumeBundleQueries.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -41,7 +42,7 @@
                           cl::desc("Enable verification of assumption cache"),
                           cl::init(false));
 
-SmallVector<WeakTrackingVH, 1> &
+SmallVector<AssumptionCache::ResultElem, 1> &
 AssumptionCache::getOrInsertAffectedValues(Value *V) {
   // Try using find_as first to avoid creating extra value handles just for the
   // purpose of doing the lookup.
@@ -50,32 +51,39 @@
     return AVI->second;
 
   auto AVIP = AffectedValues.insert(
-      {AffectedValueCallbackVH(V, this), SmallVector<WeakTrackingVH, 1>()});
+      {AffectedValueCallbackVH(V, this), SmallVector<ResultElem, 1>()});
   return AVIP.first->second;
 }
 
-static void findAffectedValues(CallInst *CI,
-                               SmallVectorImpl<Value *> &Affected) {
+static void
+findAffectedValues(CallInst *CI,
+                   SmallVectorImpl<AssumptionCache::ResultElem> &Affected) {
   // Note: This code must be kept in-sync with the code in
   // computeKnownBitsFromAssume in ValueTracking.
 
-  auto AddAffected = [&Affected](Value *V) {
+  auto AddAffected = [&Affected](Value *V, unsigned Idx =
+                                               AssumptionCache::ExprResultIdx) {
     if (isa<Argument>(V)) {
-      Affected.push_back(V);
+      Affected.push_back({V, Idx});
     } else if (auto *I = dyn_cast<Instruction>(V)) {
-      Affected.push_back(I);
+      Affected.push_back({I, Idx});
 
       // Peek through unary operators to find the source of the condition.
       Value *Op;
       if (match(I, m_BitCast(m_Value(Op))) ||
-          match(I, m_PtrToInt(m_Value(Op))) ||
-          match(I, m_Not(m_Value(Op)))) {
+          match(I, m_PtrToInt(m_Value(Op))) || match(I, m_Not(m_Value(Op)))) {
         if (isa<Instruction>(Op) || isa<Argument>(Op))
-          Affected.push_back(Op);
+          Affected.push_back({Op, Idx});
       }
     }
   };
 
+  for (unsigned Idx = 0; Idx != CI->getNumOperandBundles(); Idx++) {
+    if (CI->getOperandBundleAt(Idx).Inputs.size() > ABA_WasOn &&
+        CI->getOperandBundleAt(Idx).getTagName() != "ignore")
+      AddAffected(CI->getOperandBundleAt(Idx).Inputs[ABA_WasOn], Idx);
+  }
+
   Value *Cond = CI->getArgOperand(0), *A, *B;
   AddAffected(Cond);
 
@@ -112,28 +120,50 @@
 }
 
 void AssumptionCache::updateAffectedValues(CallInst *CI) {
-  SmallVector<Value *, 16> Affected;
+  SmallVector<AssumptionCache::ResultElem, 16> Affected;
   findAffectedValues(CI, Affected);
 
   for (auto &AV : Affected) {
-    auto &AVV = getOrInsertAffectedValues(AV);
-    if (std::find(AVV.begin(), AVV.end(), CI) == AVV.end())
-      AVV.push_back(CI);
+    auto &AVV = getOrInsertAffectedValues(AV.Assume);
+    if (std::find_if(AVV.begin(), AVV.end(), [&](ResultElem &Elem) {
+          return Elem.Assume == CI && Elem.Index == AV.Index;
+        }) == AVV.end())
+      AVV.push_back({CI, AV.Index});
   }
 }
 
 void AssumptionCache::unregisterAssumption(CallInst *CI) {
-  SmallVector<Value *, 16> Affected;
+  SmallVector<AssumptionCache::ResultElem, 16> Affected;
   findAffectedValues(CI, Affected);
 
+  // CI can affect the same value several times (e.g. through several operand
+  // bundles). All entries referring to CI are cleared during the first visit
+  // of a value, so each value must be processed only once.
+  SmallPtrSet<Value *, 16> Visited;
   for (auto &AV : Affected) {
-    auto AVI = AffectedValues.find_as(AV);
-    if (AVI != AffectedValues.end())
-      AffectedValues.erase(AVI);
+    if (!Visited.insert(AV.Assume).second)
+      continue;
+    auto AVI = AffectedValues.find_as(AV.Assume);
+    if (AVI == AffectedValues.end())
+      continue;
+    bool Found = false;
+    bool HasNonnull = false;
+    for (ResultElem &Elem : AVI->second) {
+      if (Elem.Assume == CI) {
+        Found = true;
+        Elem.Assume = nullptr;
+      }
+      HasNonnull |= !!Elem.Assume;
+    }
+    assert(Found && "assumption not registered for an affected value");
+    (void)Found;
+    // Drop the map entry once no live assumption refers to the value.
+    if (!HasNonnull)
+      AffectedValues.erase(AVI);
   }
 
   AssumeHandles.erase(
-      remove_if(AssumeHandles, [CI](WeakTrackingVH &VH) { return CI == VH; }),
+      remove_if(AssumeHandles, [CI](ResultElem &RE) { return CI == RE; }),
       AssumeHandles.end());
 }
 
@@ -177,7 +207,7 @@
   for (BasicBlock &B : F)
     for (Instruction &II : B)
       if (match(&II, m_Intrinsic<Intrinsic::assume>()))
-        AssumeHandles.push_back(&II);
+        AssumeHandles.push_back({&II, ExprResultIdx});
 
   // Mark the scan as complete.
   Scanned = true;
@@ -196,7 +226,7 @@
   if (!Scanned)
     return;
 
-  AssumeHandles.push_back(CI);
+  AssumeHandles.push_back({CI, ExprResultIdx});
 
 #ifndef NDEBUG
   assert(CI->getParent() &&
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -948,7 +948,7 @@
         continue;
       }
 
-      salvageKnowledge(&Inst);
+      salvageKnowledge(&Inst, &AC);
       salvageDebugInfoOrMarkUndef(Inst);
       removeMSSA(Inst);
       Inst.eraseFromParent();
@@ -1015,7 +1015,7 @@
               cast<ConstantInt>(KnownCond)->isOne()) {
             LLVM_DEBUG(dbgs()
                        << "EarlyCSE removing guard: " << Inst << '\n');
-            salvageKnowledge(&Inst);
+            salvageKnowledge(&Inst, &AC);
             removeMSSA(Inst);
             Inst.eraseFromParent();
             Changed = true;
@@ -1051,7 +1051,7 @@
           Changed = true;
         }
         if (isInstructionTriviallyDead(&Inst, &TLI)) {
-          salvageKnowledge(&Inst);
+          salvageKnowledge(&Inst, &AC);
           removeMSSA(Inst);
           Inst.eraseFromParent();
           Changed = true;
@@ -1077,7 +1077,7 @@
         if (auto *I = dyn_cast<Instruction>(V))
           I->andIRFlags(&Inst);
         Inst.replaceAllUsesWith(V);
-        salvageKnowledge(&Inst);
+        salvageKnowledge(&Inst, &AC);
         removeMSSA(Inst);
         Inst.eraseFromParent();
         Changed = true;
@@ -1138,7 +1138,7 @@
           }
           if (!Inst.use_empty())
             Inst.replaceAllUsesWith(Op);
-          salvageKnowledge(&Inst);
+          salvageKnowledge(&Inst, &AC);
           removeMSSA(Inst);
           Inst.eraseFromParent();
           Changed = true;
@@ -1182,7 +1182,7 @@
         }
         if (!Inst.use_empty())
           Inst.replaceAllUsesWith(InVal.first);
-        salvageKnowledge(&Inst);
+        salvageKnowledge(&Inst, &AC);
         removeMSSA(Inst);
         Inst.eraseFromParent();
         Changed = true;
@@ -1235,7 +1235,7 @@
           LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
           continue;
         }
-        salvageKnowledge(&Inst);
+        salvageKnowledge(&Inst, &AC);
         removeMSSA(Inst);
         Inst.eraseFromParent();
         Changed = true;
@@ -1271,7 +1271,7 @@
             if (!DebugCounter::shouldExecute(CSECounter)) {
               LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
             } else {
-              salvageKnowledge(&Inst);
+              salvageKnowledge(LastStore, &AC);
               removeMSSA(*LastStore);
               LastStore->eraseFromParent();
               Changed = true;
diff --git a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
--- a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
+++ b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
 #include "llvm/Analysis/AssumeBundleQueries.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/InstIterator.h"
@@ -222,9 +223,12 @@
   return Builder.build();
 }
 
-void llvm::salvageKnowledge(Instruction *I) {
-  if (Instruction *Intr = buildAssumeFromInst(I))
+void llvm::salvageKnowledge(Instruction *I, AssumptionCache *AC) {
+  if (IntrinsicInst *Intr = buildAssumeFromInst(I)) {
     Intr->insertBefore(I);
+    if (AC)
+      AC->registerAssumption(Intr);
+  }
 }
 
 PreservedAnalyses AssumeBuilderPass::run(Function &F,
diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp
--- a/llvm/lib/Transforms/Utils/InlineFunction.cpp
+++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp
@@ -1838,9 +1838,11 @@
     // check what will be known at the start of the inlined code.
     AddAlignmentAssumptions(CS, IFI);
 
+    AssumptionCache *AC =
+        IFI.GetAssumptionCache ? &(*IFI.GetAssumptionCache)(*Caller) : nullptr;
+
     /// Preserve all attributes on of the call and its parameters.
-    if (Instruction *Assume = buildAssumeFromInst(CS.getInstruction()))
-      Assume->insertBefore(CS.getInstruction());
+    salvageKnowledge(CS.getInstruction(), AC);
 
     // We want the inliner to prune the code as it copies.  We would LOVE to
     // have no dead or constant instructions leftover after inlining occurs
diff --git a/llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp b/llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp
--- a/llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp
+++ b/llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Analysis/AssumeBundleQueries.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/CallSite.h"
@@ -510,3 +511,83 @@
   // large.
   RunRandTest(9876789, 100000, -0, 7, 100);
 }
+
+TEST(AssumeQueryAPI, AssumptionCache) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(
+      "declare void @llvm.assume(i1)\n"
+      "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3, i1 %B) {\n"
+      "call void @llvm.assume(i1 true) [\"nonnull\"(i32* %P), \"align\"(i32* "
+      "%P2, i32 4), \"align\"(i32* %P, i32 8)]\n"
+      "call void @llvm.assume(i1 %B) [\"test\"(i32* %P1), "
+      "\"dereferenceable\"(i32* %P, i32 4)]\n"
+      "ret void\n}\n",
+      Err, C);
+  if (!Mod)
+    Err.print("AssumeQueryAPI", errs());
+  ASSERT_TRUE(Mod);
+  Function *F = Mod->getFunction("test");
+  BasicBlock::iterator First = F->begin()->begin();
+  BasicBlock::iterator Second = F->begin()->begin();
+  Second++;
+  AssumptionCacheTracker ACT;
+  AssumptionCache &AC = ACT.getAssumptionCache(*F);
+  auto AR = AC.assumptionsFor(F->getArg(3));
+  ASSERT_EQ(AR.size(), 0u);
+  AR = AC.assumptionsFor(F->getArg(1));
+  ASSERT_EQ(AR.size(), 1u);
+  ASSERT_EQ(AR[0].Index, 0u);
+  ASSERT_EQ(AR[0].Assume, &*Second);
+  AR = AC.assumptionsFor(F->getArg(2));
+  ASSERT_EQ(AR.size(), 1u);
+  ASSERT_EQ(AR[0].Index, 1u);
+  ASSERT_EQ(AR[0].Assume, &*First);
+  AR = AC.assumptionsFor(F->getArg(0));
+  ASSERT_EQ(AR.size(), 3u);
+  llvm::sort(AR,
+             [](const auto &L, const auto &R) { return L.Index < R.Index; });
+  ASSERT_EQ(AR[0].Assume, &*First);
+  ASSERT_EQ(AR[0].Index, 0u);
+  ASSERT_EQ(AR[1].Assume, &*Second);
+  ASSERT_EQ(AR[1].Index, 1u);
+  ASSERT_EQ(AR[2].Assume, &*First);
+  ASSERT_EQ(AR[2].Index, 2u);
+  AR = AC.assumptionsFor(F->getArg(4));
+  ASSERT_EQ(AR.size(), 1u);
+  ASSERT_EQ(AR[0].Assume, &*Second);
+  ASSERT_EQ(AR[0].Index, AssumptionCache::ExprResultIdx);
+  AC.unregisterAssumption(cast<CallInst>(&*Second));
+  AR = AC.assumptionsFor(F->getArg(1));
+  ASSERT_EQ(AR.size(), 0u);
+  AR = AC.assumptionsFor(F->getArg(0));
+  ASSERT_EQ(AR.size(), 3u);
+  llvm::sort(AR,
+             [](const auto &L, const auto &R) { return L.Index < R.Index; });
+  ASSERT_EQ(AR[0].Assume, &*First);
+  ASSERT_EQ(AR[0].Index, 0u);
+  ASSERT_EQ(AR[1].Assume, nullptr);
+  ASSERT_EQ(AR[1].Index, 1u);
+  ASSERT_EQ(AR[2].Assume, &*First);
+  ASSERT_EQ(AR[2].Index, 2u);
+  AR = AC.assumptionsFor(F->getArg(2));
+  ASSERT_EQ(AR.size(), 1u);
+  ASSERT_EQ(AR[0].Index, 1u);
+  ASSERT_EQ(AR[0].Assume, &*First);
+
+  // Unregister an assumption that affects %P through two operand bundles
+  // while another live assumption still refers to %P.
+  AssumptionCacheTracker ACT2;
+  AssumptionCache &AC2 = ACT2.getAssumptionCache(*F);
+  AC2.unregisterAssumption(cast<CallInst>(&*First));
+  AR = AC2.assumptionsFor(F->getArg(2));
+  ASSERT_EQ(AR.size(), 0u);
+  AR = AC2.assumptionsFor(F->getArg(0));
+  ASSERT_EQ(AR.size(), 3u);
+  llvm::sort(AR,
+             [](const auto &L, const auto &R) { return L.Index < R.Index; });
+  ASSERT_EQ(AR[0].Assume, nullptr);
+  ASSERT_EQ(AR[1].Assume, &*Second);
+  ASSERT_EQ(AR[1].Index, 1u);
+  ASSERT_EQ(AR[2].Assume, nullptr);
+}