diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h --- a/llvm/include/llvm/Transforms/IPO/Attributor.h +++ b/llvm/include/llvm/Transforms/IPO/Attributor.h @@ -154,6 +154,7 @@ /// are floating values that do not have a corresponding attribute list /// position. struct IRPosition { + using CallBaseContext = CallBase; /// The positions we distinguish in the IR. enum Kind : char { @@ -174,27 +175,34 @@ IRPosition() : Enc(nullptr, ENC_VALUE) { verify(); } /// Create a position describing the value of \p V. - static const IRPosition value(const Value &V) { + static const IRPosition value(const Value &V, + const CallBaseContext *CBContext = nullptr) { if (auto *Arg = dyn_cast(&V)) - return IRPosition::argument(*Arg); + return IRPosition::argument(*Arg, CBContext); if (auto *CB = dyn_cast(&V)) return IRPosition::callsite_returned(*CB); - return IRPosition(const_cast(V), IRP_FLOAT); + return IRPosition(const_cast(V), IRP_FLOAT, CBContext); } /// Create a position describing the function scope of \p F. - static const IRPosition function(const Function &F) { - return IRPosition(const_cast(F), IRP_FUNCTION); + /// \p CBContext is used for call base specific analysis. + static const IRPosition function(const Function &F, + const CallBaseContext *CBContext = nullptr) { + return IRPosition(const_cast(F), IRP_FUNCTION, CBContext); } /// Create a position describing the returned value of \p F. - static const IRPosition returned(const Function &F) { - return IRPosition(const_cast(F), IRP_RETURNED); + /// \p CBContext is used for call base specific analysis. + static const IRPosition returned(const Function &F, + const CallBaseContext *CBContext = nullptr) { + return IRPosition(const_cast(F), IRP_RETURNED, CBContext); } /// Create a position describing the argument \p Arg.
- static const IRPosition argument(const Argument &Arg) { - return IRPosition(const_cast(Arg), IRP_ARGUMENT); + /// \p CBContext is used for call base specific analysis. + static const IRPosition argument(const Argument &Arg, + const CallBaseContext *CBContext = nullptr) { + return IRPosition(const_cast(Arg), IRP_ARGUMENT, CBContext); } /// Create a position describing the function scope of \p CB. @@ -230,16 +238,20 @@ /// If \p IRP is a call site (see isAnyCallSitePosition()) then the result /// will be a call site position, otherwise the function position of the /// associated function. - static const IRPosition function_scope(const IRPosition &IRP) { + static const IRPosition + function_scope(const IRPosition &IRP, + const CallBaseContext *CBContext = nullptr) { if (IRP.isAnyCallSitePosition()) { return IRPosition::callsite_function( cast(IRP.getAnchorValue())); } assert(IRP.getAssociatedFunction()); - return IRPosition::function(*IRP.getAssociatedFunction()); + return IRPosition::function(*IRP.getAssociatedFunction(), CBContext); } - bool operator==(const IRPosition &RHS) const { return Enc == RHS.Enc; } + bool operator==(const IRPosition &RHS) const { + return Enc == RHS.Enc && RHS.CBContext == CBContext; + } bool operator!=(const IRPosition &RHS) const { return !(*this == RHS); } /// Return the value this abstract attribute is anchored with. @@ -439,6 +451,19 @@ } } + /// Return the same position without the call base context. + IRPosition stripCallBaseContext() const { + IRPosition Result = *this; + Result.CBContext = nullptr; + return Result; + } + + /// Get the call base context from the position. + const CallBaseContext *getCallBaseContext() const { return CBContext; } + + /// Check if the position has any call base context. + bool hasCallBaseContext() const { return CBContext != nullptr; } + /// Special DenseMap key values. /// ///{ @@ -451,10 +476,15 @@ private: /// Private constructor for special values only!
- explicit IRPosition(void *Ptr) { Enc.setFromOpaqueValue(Ptr); } + explicit IRPosition(void *Ptr, const CallBaseContext *CBContext = nullptr) + : CBContext(CBContext) { + Enc.setFromOpaqueValue(Ptr); + } /// IRPosition anchored at \p AnchorVal with kind/argument numbet \p PK. - explicit IRPosition(Value &AnchorVal, Kind PK) { + explicit IRPosition(Value &AnchorVal, Kind PK, + const CallBaseContext *CBContext = nullptr) + : CBContext(CBContext) { switch (PK) { case IRPosition::IRP_INVALID: llvm_unreachable("Cannot create invalid IRP with an anchor value!"); @@ -557,16 +587,27 @@ PointerIntPair Enc; ///} + /// Call base context. Used for callsite specific analysis. + const CallBaseContext *CBContext = nullptr; + /// Return the encoding bits. char getEncodingBits() const { return Enc.getInt(); } }; /// Helper that allows IRPosition as a key in a DenseMap. -template <> struct DenseMapInfo : DenseMapInfo { +template <> struct DenseMapInfo { static inline IRPosition getEmptyKey() { return IRPosition::EmptyKey; } static inline IRPosition getTombstoneKey() { return IRPosition::TombstoneKey; } + static unsigned getHashValue(const IRPosition &IRP) { + return (DenseMapInfo::getHashValue(IRP) << 4) ^ + (DenseMapInfo::getHashValue(IRP.getCallBaseContext())); + } + + static bool isEqual(const IRPosition &LHS, const IRPosition &RHS) { + return LHS == RHS; + } }; /// A visitor class for IR positions. @@ -877,11 +918,25 @@ /// only the Attributor itself. Initial seeding of AAs can be done via this /// function. template - const AAType &getOrCreateAAFor(const IRPosition &IRP, + const AAType &getOrCreateAAFor(IRPosition IRP, const AbstractAttribute *QueryingAA = nullptr, bool TrackDependence = false, DepClassTy DepClass = DepClassTy::OPTIONAL, bool ForceUpdate = false) { +#ifdef EXPENSIVE_CHECKS + // Don't allow callbase information to leak.
+ if (auto CBContext = IRP.getCallBaseContext()) { + assert( + ((CBContext->getCalledFunction() == IRP.getAnchorScope() || + !QueryingAA || + QueryingAA->getIRPosition().isAnyCallSitePosition())) && + "non callsite positions are not allowed to propagate CallBaseContext " + "across functions"); + } +#endif + if (!shouldPropagateCallBaseContext(IRP)) + IRP = IRP.stripCallBaseContext(); + if (AAType *AAPtr = lookupAAFor(IRP, QueryingAA, TrackDependence)) { if (ForceUpdate) updateAA(*AAPtr); @@ -1353,6 +1408,9 @@ const AbstractAttribute *QueryingAA, bool &AllCallSitesKnown); + /// Determine if CallBase context in \p IRP should be propagated. + bool shouldPropagateCallBaseContext(const IRPosition &IRP); + /// Apply all requested function signature rewrites /// (\see registerFunctionSignatureRewrite) and return Changed if the module /// was altered. diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp --- a/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/llvm/lib/Transforms/IPO/Attributor.cpp @@ -85,6 +85,11 @@ "allowed to be seeded."), cl::ZeroOrMore, cl::CommaSeparated); +static cl::opt EnableCallSiteSpecific( + "attributor-enable-call-site-specific-deduction", cl::Hidden, + cl::desc("Allow the Attributor to do call site specific analysis"), + cl::init(false)); + /// Logic operators for the change status enum class.
/// ///{ @@ -412,6 +417,8 @@ #ifdef EXPENSIVE_CHECKS switch (getPositionKind()) { case IRP_INVALID: + assert((CBContext == nullptr) && + "Invalid position must not have CallBaseContext!"); assert(!Enc.getOpaqueValue() && "Expected a nullptr for an invalid position!"); return; @@ -427,12 +434,16 @@ "Associated value mismatch!"); return; case IRP_CALL_SITE_RETURNED: + assert((CBContext == nullptr) && + "'call site returned' position must not have CallBaseContext!"); assert((isa(getAsValuePtr())) && "Expected call base for 'call site returned' position!"); assert(getAsValuePtr() == &getAssociatedValue() && "Associated value mismatch!"); return; case IRP_CALL_SITE: + assert((CBContext == nullptr) && + "'call site function' position must not have CallBaseContext!"); assert((isa(getAsValuePtr())) && "Expected call base for 'call site function' position!"); assert(getAsValuePtr() == &getAssociatedValue() && @@ -451,6 +462,8 @@ "Associated value mismatch!"); return; case IRP_CALL_SITE_ARGUMENT: { + assert((CBContext == nullptr) && + "'call site argument' position must not have CallBaseContext!"); Use *U = getAsUsePtr(); assert(U && "Expected use for a 'call site argument' position!"); assert(isa(U->getUser()) && @@ -785,6 +798,13 @@ return true; } +bool Attributor::shouldPropagateCallBaseContext(const IRPosition &IRP) { + // TODO: Maintain a cache of Values that are on the pathway from an + // Argument to an Instruction that would affect the liveness/return state. + // For now IRP is ignored; only the command line option decides. + return EnableCallSiteSpecific; +} + bool Attributor::checkForAllReturnedValuesAndReturnInsts( function_ref &)> Pred, const AbstractAttribute &QueryingAA) { @@ -1052,6 +1072,9 @@ if (!State.isAtFixpoint()) State.indicateOptimisticFixpoint(); + // Call base context specific attributes must never be manifested. + if (AA->hasCallBaseContext()) + continue; // If the state is invalid, we do not try to manifest it.
if (!State.isValidState()) continue; @@ -2004,8 +2027,12 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, const IRPosition &Pos) { const Value &AV = Pos.getAssociatedValue(); - return OS << "{" << Pos.getPositionKind() << ":" << AV.getName() << " [" - << Pos.getAnchorValue().getName() << "@" << Pos.getArgNo() << "]}"; + OS << "{" << Pos.getPositionKind() << ":" << AV.getName() << " [" + << Pos.getAnchorValue().getName() << "@" << Pos.getArgNo() << "]"; + + if (Pos.hasCallBaseContext()) + OS << "[cb_context:" << *Pos.getCallBaseContext() << "]"; + return OS << "}"; } raw_ostream &llvm::operator<<(raw_ostream &OS, const IntegerRangeState &S) { diff --git a/llvm/unittests/Transforms/IPO/AttributorTest.cpp b/llvm/unittests/Transforms/IPO/AttributorTest.cpp --- a/llvm/unittests/Transforms/IPO/AttributorTest.cpp +++ b/llvm/unittests/Transforms/IPO/AttributorTest.cpp @@ -21,6 +21,25 @@ namespace llvm { +TEST_F(AttributorTestBase, IRPPositionCallBaseContext) { + const char *ModuleString = R"( + define i32 @foo(i32 %a) { + entry: + ret i32 %a + } + )"; + + parseModule(ModuleString); + + Function *F = M->getFunction("foo"); + IRPosition Pos = IRPosition::function(*F, (const llvm::CallBase *)0xDEADBEEF); + EXPECT_TRUE(Pos.hasCallBaseContext()); + EXPECT_FALSE(Pos.stripCallBaseContext().hasCallBaseContext()); + // Positions that differ only in their call base context are not equal. + EXPECT_FALSE(Pos == IRPosition::function(*F)); + EXPECT_TRUE(Pos.stripCallBaseContext() == IRPosition::function(*F)); +} + TEST_F(AttributorTestBase, TestCast) { const char *ModuleString = R"( define i32 @foo(i32 %a, i32 %b) {