diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -1427,8 +1427,16 @@
                             const AbstractAttribute *QueryingAA,
                             bool &AllCallSitesKnown);
 
+  /// Populate the cache of values of function \p Fn that the call base
+  /// context is going to get propagated to.
+  void populateArgumentPathwayForFunction(Function *Fn);
+
   /// Determine if CallBase context in \p IRP should be propagated.
-  bool shouldPropagateCallBaseContext(const IRPosition &IRP);
+  /// \p IgnoreCommandLine is for ignoring the
+  /// `attributor-enable-call-site-specific-deduction` command line argument,
+  /// only used for testing.
+  bool shouldPropagateCallBaseContext(const IRPosition &IRP,
+                                      bool IgnoreCommandLine = false);
 
   /// Apply all requested function signature rewrites
   /// (\see registerFunctionSignatureRewrite) and return Changed if the module
@@ -1516,6 +1524,12 @@
   SmallPtrSet<BasicBlock *, 8> ToBeDeletedBlocks;
   SmallDenseSet<WeakVH, 8> ToBeDeletedInsts;
   ///}
+
+  /// Values per function that are relevant to the call base context. Filled
+  /// lazily by populateArgumentPathwayForFunction().
+  DenseMap<Function *, SmallPtrSet<Value *, 8>> ArgumentPathway;
+
+  friend class AttributorTestBase;
 };
 
 /// An interface to query the internal state of an abstract attribute.
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -90,6 +90,13 @@
     cl::desc("Allow the Attributor to do call site specific analysis"),
     cl::init(false));
 
+static cl::opt<unsigned>
+    CallSiteSpecificDepth("attributor-call-site-specific-deduction-depth",
+                          cl::Hidden,
+                          cl::desc("Maximum depth of values that call site "
+                                   "specific analysis context will reach."),
+                          cl::init(6));
+
 /// Logic operators for the change status enum class.
 ///
 ///{
@@ -807,11 +814,104 @@
   return true;
 }
 
-bool Attributor::shouldPropagateCallBaseContext(const IRPosition &IRP) {
-  // TODO: Maintain a cache of Values that are
-  // on the pathway from a Argument to a Instruction that would effect the
-  // liveness/return state etc.
-  return EnableCallSiteSpecific;
+void Attributor::populateArgumentPathwayForFunction(Function *Fn) {
+  // The pathway is the intersection of the values reachable forwards (via
+  // users) from the arguments of Fn and the values reachable backwards (via
+  // operands) from the terminators the forward walk reaches. Both walks are
+  // bounded by CallSiteSpecificDepth steps.
+  assert(Fn != nullptr && "There must be a function!");
+
+  SmallVector<Value *, 8> ReachedTerminators;
+  SmallVector<Value *, 8> Worklist(
+      pointer_iterator<Function::arg_iterator>(Fn->arg_begin()),
+      pointer_iterator<Function::arg_iterator>(Fn->arg_end()));
+  SmallPtrSet<Value *, 16> ForwardsVisited;
+
+  // Forward walk: in iteration I, Worklist holds the values I uses away from
+  // an argument.
+  for (unsigned I = 0; I < CallSiteSpecificDepth && !Worklist.empty(); ++I) {
+    SmallVector<Value *, 8> NewWorklist;
+    while (!Worklist.empty()) {
+      Value *Val = Worklist.pop_back_val();
+
+      // Avoid infinite loops.
+      if (!ForwardsVisited.insert(Val).second)
+        continue;
+
+      // Terminators are the instructions that benefit from call site
+      // specific information; they seed the backward walk.
+      auto *Inst = dyn_cast<Instruction>(Val);
+      if (Inst && Inst->isTerminator())
+        ReachedTerminators.push_back(Inst);
+
+      for (User *U : Val->users())
+        NewWorklist.push_back(U);
+    }
+    Worklist.swap(NewWorklist);
+  }
+
+  // Backward walk from the reached terminators. Only operands that the
+  // forward walk visited are followed; this is what forms the intersection.
+  Worklist.assign(ReachedTerminators.begin(), ReachedTerminators.end());
+  SmallPtrSet<Value *, 8> BackwardsVisited;
+  for (unsigned I = 0; I < CallSiteSpecificDepth && !Worklist.empty(); ++I) {
+    SmallVector<Value *, 8> NewWorklist;
+    while (!Worklist.empty()) {
+      Value *Val = Worklist.pop_back_val();
+
+      // Avoid infinite loops.
+      if (!BackwardsVisited.insert(Val).second)
+        continue;
+
+      if (auto *Inst = dyn_cast<Instruction>(Val))
+        for (Value *Op : Inst->operand_values())
+          if (ForwardsVisited.count(Op))
+            NewWorklist.push_back(Op);
+    }
+    Worklist.swap(NewWorklist);
+  }
+
+  LLVM_DEBUG({
+    dbgs() << "[Attributor] Argument pathway for function: " << Fn->getName()
+           << "\n";
+    for (Value *V : BackwardsVisited)
+      dbgs() << "[Attributor]" << *V << " is in Argument Pathway.\n";
+  });
+  ArgumentPathway[Fn] = std::move(BackwardsVisited);
+}
+
+bool Attributor::shouldPropagateCallBaseContext(const IRPosition &IRP,
+                                                bool IgnoreCommandLine) {
+  if (!IgnoreCommandLine && !EnableCallSiteSpecific)
+    return false;
+
+  // Certain positions must always propagate call base context.
+  switch (IRP.getPositionKind()) {
+  case IRPosition::IRP_RETURNED:
+  case IRPosition::IRP_ARGUMENT:
+  case IRPosition::IRP_FUNCTION:
+    return true;
+  case IRPosition::IRP_FLOAT:
+    break;
+  default:
+    return false;
+  }
+
+  // Look the value up in the pathway of the function that contains it. For a
+  // floating position anchored at a call, getAssociatedFunction() would name
+  // the callee instead, whose pathway never contains the call itself.
+  Function *Fn = IRP.getAnchorScope();
+  if (!Fn)
+    return false;
+
+  // The pathway of Fn is computed on first use and then cached.
+  auto ArgPathwayIt = ArgumentPathway.find(Fn);
+  if (ArgPathwayIt == ArgumentPathway.end()) {
+    populateArgumentPathwayForFunction(Fn);
+    ArgPathwayIt = ArgumentPathway.find(Fn);
+  }
+
+  return ArgPathwayIt->second.count(&IRP.getAnchorValue());
 }
 
 bool Attributor::checkForAllReturnedValuesAndReturnInsts(
diff --git a/llvm/unittests/Transforms/IPO/AttributorTest.cpp b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
--- a/llvm/unittests/Transforms/IPO/AttributorTest.cpp
+++ b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
@@ -9,14 +9,9 @@
 #include "llvm/Transforms/IPO/Attributor.h"
 #include "AttributorTestBase.h"
 #include "llvm/ADT/StringRef.h"
-#include "llvm/Analysis/CGSCCPassManager.h"
-#include "llvm/Analysis/CallGraphSCCPass.h"
-#include "llvm/Analysis/LoopAnalysisManager.h"
-#include "llvm/AsmParser/Parser.h"
-#include "llvm/Support/Allocator.h"
-#include "llvm/Testing/Support/Error.h"
-#include "llvm/Transforms/Utils/CallGraphUpdater.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "gtest/gtest.h"
 #include <memory>
 
 namespace llvm {
@@ -37,6 +32,63 @@
   EXPECT_FALSE(Pos.stripCallBaseContext().hasCallBaseContext());
 }
 
+TEST_F(AttributorTestBase, CallBaseContextPathway) {
+  const char *ModuleString = R"(
+    define i32 @foo(i32 %a, i32 %b) {
+      ; %a reaches the return in 5 steps, within the default depth of 6
+      ; (should have cb context).
+      %a1 = add nsw i32 %a, 1
+      %a2 = add nsw i32 %a1, 1
+      %a3 = add nsw i32 %a2, 1
+
+      ; %b needs 9 steps to reach the return, more than the depth allows
+      ; (should not have cb context).
+      %b1 = add nsw i32 %b, 1
+      %b2 = add nsw i32 %b1, 1
+      %b3 = add nsw i32 %b2, 1
+      %b4 = add nsw i32 %b3, 1
+      %b5 = add nsw i32 %b4, 1
+      %b6 = add nsw i32 %b5, 1
+      %b7 = add nsw i32 %b6, 1
+
+      %c = add nsw i32 %b7, %a3
+      ret i32 %c
+    })";
+
+  auto Instr = [](Function *F, StringRef Name) -> Instruction * {
+    for (Instruction &I : instructions(F))
+      if (I.getName() == Name)
+        return &I;
+    llvm_unreachable("Instruction name not found!");
+  };
+
+  Module &M = parseModule(ModuleString);
+  Attributor &A = createAttributor(M);
+  Function *F = M.getFunction("foo");
+
+#define EXPECT_PROPAGATE(n)                                                    \
+  EXPECT_TRUE(doesPropagateCBContext(A, *Instr(F, (n))))
+#define EXPECT_NOT_PROPAGATE(n)                                                \
+  EXPECT_FALSE(doesPropagateCBContext(A, *Instr(F, (n))))
+
+  EXPECT_PROPAGATE("a1");
+  EXPECT_PROPAGATE("a2");
+  EXPECT_PROPAGATE("a3");
+
+  EXPECT_NOT_PROPAGATE("b1");
+  EXPECT_NOT_PROPAGATE("b2");
+  EXPECT_NOT_PROPAGATE("b3");
+  EXPECT_NOT_PROPAGATE("b4");
+  EXPECT_NOT_PROPAGATE("b5");
+  EXPECT_NOT_PROPAGATE("b6");
+  EXPECT_NOT_PROPAGATE("b7");
+
+  EXPECT_PROPAGATE("c");
+
+#undef EXPECT_PROPAGATE
+#undef EXPECT_NOT_PROPAGATE
+}
+
 TEST_F(AttributorTestBase, TestCast) {
   const char *ModuleString = R"(
     define i32 @foo(i32 %a, i32 %b) {
@@ -47,16 +99,7 @@
   )";
 
   Module &M = parseModule(ModuleString);
-
-  SetVector<Function *> Functions;
-  AnalysisGetter AG;
-  for (Function &F : M)
-    Functions.insert(&F);
-
-  CallGraphUpdater CGUpdater;
-  BumpPtrAllocator Allocator;
-  InformationCache InfoCache(M, AG, Allocator, nullptr);
-  Attributor A(Functions, InfoCache, CGUpdater);
+  Attributor &A = createAttributor(M);
 
   Function *F = M.getFunction("foo");
 
@@ -72,4 +115,4 @@
   ASSERT_TRUE(SSucc);
 }
 
-} // namespace llvm
\ No newline at end of file
+} // namespace llvm
diff --git a/llvm/unittests/Transforms/IPO/AttributorTestBase.h b/llvm/unittests/Transforms/IPO/AttributorTestBase.h
--- a/llvm/unittests/Transforms/IPO/AttributorTestBase.h
+++ b/llvm/unittests/Transforms/IPO/AttributorTestBase.h
@@ -12,6 +12,7 @@
 #ifndef LLVM_UNITTESTS_TRANSFORMS_ATTRIBUTOR_TESTBASE_H
 #define LLVM_UNITTESTS_TRANSFORMS_ATTRIBUTOR_TESTBASE_H
 
+#include "llvm/ADT/iterator.h"
 #include "llvm/Analysis/CGSCCPassManager.h"
 #include "llvm/Analysis/CallGraphSCCPass.h"
 #include "llvm/AsmParser/Parser.h"
@@ -31,6 +32,16 @@
 protected:
   std::unique_ptr<LLVMContext> Ctx;
   std::unique_ptr<Module> M;
+
+  /// State that the Attributor built by createAttributor() keeps references
+  /// to. It lives in the fixture so that it outlives \c A; members are
+  /// destroyed in reverse declaration order, so \c A has to stay last.
+  SetVector<Function *> Functions;
+  AnalysisGetter AG;
+  CallGraphUpdater CGUpdater;
+  BumpPtrAllocator Allocator;
+  std::unique_ptr<InformationCache> InfoCache;
+  std::unique_ptr<Attributor> A;
 
   AttributorTestBase() : Ctx(new LLVMContext) {}
 
@@ -40,8 +51,28 @@
     EXPECT_TRUE(M);
     return *M;
   }
+
+  /// Create an Attributor over all functions of \p M. The returned reference
+  /// stays valid until the next call or until the fixture is destroyed.
+  Attributor &createAttributor(Module &M) {
+    // Destroy a previous Attributor before the state it references changes.
+    A.reset();
+    Functions.clear();
+    Functions.insert(pointer_iterator<Module::iterator>(M.begin()),
+                     pointer_iterator<Module::iterator>(M.end()));
+    InfoCache = std::make_unique<InformationCache>(M, AG, Allocator, nullptr);
+    A = std::make_unique<Attributor>(Functions, *InfoCache, CGUpdater);
+    return *A;
+  }
+
+  /// Ask \p A whether call base context propagates to \p Inst, ignoring the
+  /// command line flag that normally disables call site specific deduction.
+  bool doesPropagateCBContext(Attributor &A, Instruction &Inst) {
+    return A.shouldPropagateCallBaseContext(IRPosition::value(Inst),
+                                            true /* IgnoreCommandLine */);
+  }
 };
 
 } // namespace llvm
 
-#endif
\ No newline at end of file
+#endif