diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp --- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -9524,8 +9524,30 @@ bool CanReachUnknownCallee = false; }; -struct AAHotColdFunction : public AAHotCold { - AAHotColdFunction(const IRPosition &IRP, Attributor &A) : AAHotCold(IRP, A) {} +struct AAHotColdImpl : public AAHotCold { + AAHotColdImpl(const IRPosition &IRP, Attributor &A) : AAHotCold(IRP, A) {} + + const std::string getAsStr() const override { + std::string State; + if (IsHot.hasValue()) + State = IsHot.getValue() ? "Hot" : "Cold"; + else + State = "Unknown"; + + return "HotCold[" + State + "]"; + } + + void trackStatistics() const override {} + + Optional<bool> isAssumedHot() const override { return IsHot; } + + /// True if the function is hot, false if it is cold, None if unknown. + Optional<bool> IsHot = llvm::None; +}; + +struct AAHotColdFunction : public AAHotColdImpl { + AAHotColdFunction(const IRPosition &IRP, Attributor &A) + : AAHotColdImpl(IRP, A) {} /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { @@ -9548,23 +9570,24 @@ Optional<bool> OldIsHot = IsHot; // If all of the callers of a function are cold, we can assume that the - // function is cold too. But we can't assume it is hot. - // TODO: Propagate accross calls inside the same block. + // function is cold too. But we can't assume it is hot if all callers are + // hot. auto CallSiteCheck = [&](AbstractCallSite ACS) { - Function *Caller = ACS.getInstruction()->getFunction(); + CallBase *CB = ACS.getInstruction(); const AAHotCold &AA = A.getAAFor<AAHotCold>( - *this, IRPosition::function(*Caller), DepClassTy::REQUIRED); + *this, IRPosition::callsite_function(*CB), DepClassTy::REQUIRED); Optional<bool> AAAssumedHot = AA.isAssumedHot(); if (!AAAssumedHot.hasValue()) return true; // We can only distribute cold this way.
if (!IsHot.hasValue() && !AAAssumedHot.getValue()) IsHot = AAAssumedHot; else if (IsHot != AAAssumedHot) // Conflicting hot cold input from callers. - // give up. + // A hot caller, or a mix of hot and cold callers, means we cannot + // conclude that the function is cold, so we give up. return false; return true; }; @@ -9575,10 +9598,64 @@ if (!A.checkForAllCallSites(CallSiteCheck, *this, true, AllCallSitesKnown)) { IsHot = llvm::None; - indicatePessimisticFixpoint(); } - return OldIsHot == IsHot ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED; + ChangeStatus Change = + OldIsHot == IsHot ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED; + + // If there is a cold function being called from a basic block, + // we can assume that any other call site inside that block is cold too. + // If all call sites are cold, we can assume that the function is cold. + // Here we collect information about call sites and get a hot/cold state + // for a basic block. + auto CheckCall = [&](Instruction &Inst) { + BasicBlock *ParentBlock = Inst.getParent(); + + CallBase &CB = cast<CallBase>(Inst); + Function *Fn = CB.getCalledFunction(); + Optional<bool> CallSiteInfo = llvm::None; + + if (Fn) { + // Use the callee's function-level state, not this call site's. + const auto &AA = A.getAAFor<AAHotCold>(*this, IRPosition::function(*Fn), + DepClassTy::REQUIRED); + CallSiteInfo = AA.isAssumedHot(); + // llvm::None means we don't know. + // we can skip this call. + if (!CallSiteInfo.hasValue()) + return true; + } + // If this block contains a call with an unknown callee, assume that we + // can't know. We are being conservative here. + + bool IsNew = !BlockStatus.count(ParentBlock); + Optional<bool> &CurrentState = BlockStatus[ParentBlock]; + Optional<bool> OldState = CurrentState; + + // If this is the first call we visit in this block (no entry existed + // before the operator[] above), just take the call site's state.
+ if (IsNew) + CurrentState = CallSiteInfo; + else if (CurrentState != CallSiteInfo) + // If we have a conflict, assume that we can't know. + CurrentState = llvm::None; + + // Track changes. + Change |= CurrentState == OldState ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + + return true; + }; + + bool UsedAssumed = false; + if (!A.checkForAllCallLikeInstructions(CheckCall, *this, UsedAssumed)) { + // Not all calls were visited: drop the data and report the change. + if (!BlockStatus.empty()) + Change = ChangeStatus::CHANGED; + BlockStatus.clear(); + } + + return Change; } ChangeStatus manifest(Attributor &A) override { @@ -9592,24 +9669,50 @@ return IRAttributeManifest::manifestAttrs(A, this->getIRPosition(), Attrs); } - const std::string getAsStr() const override { - std::string State; - if (IsHot.hasValue()) - State = IsHot.getValue() ? "Hot" : "Cold"; - else - State = "Unknown"; + /// This represents the information we have about a basic block. + /// We use this information to determine if a call site inside the + /// block is hot or cold. + /// A block not being in here means that we don't know if it is hot or cold. + /// An llvm::None value means that we know we can't tell for sure. + DenseMap<BasicBlock *, Optional<bool>> BlockStatus; +}; - return "HotCold[" + State + "]"; - } +struct AAHotColdCallSite : public AAHotColdImpl { + AAHotColdCallSite(const IRPosition &IRP, Attributor &A) + : AAHotColdImpl(IRP, A) {} - void trackStatistics() const override {} + ChangeStatus updateImpl(Attributor &A) override { + Optional<bool> OldIsHot = IsHot; - Optional<bool> isAssumedHot() const override { return IsHot; } + auto CheckChange = [&]() { + return OldIsHot == IsHot ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + }; - /// true if the function is hot. false if it cold and None if we don't know. - Optional<bool> IsHot = llvm::None; -}; + Function *Anchor = getAnchorScope(); + auto &AnchorAA = A.getAAFor<AAHotCold>( + *this, IRPosition::function(*Anchor), DepClassTy::REQUIRED); + Optional<bool> AnchorIsHot = AnchorAA.isAssumedHot(); + // We can only propagate cold this way.
+ if (AnchorIsHot && !AnchorIsHot.getValue()) { + IsHot = false; + return CheckChange(); + } + + // Bind by reference; a copy would duplicate the whole anchor AA. + const auto &AnchorFunctionAA = + static_cast<const AAHotColdFunction &>(AnchorAA); + + BasicBlock *ParentBlock = getCtxI()->getParent(); + // lookup() yields llvm::None for blocks without an entry. + Optional<bool> State = AnchorFunctionAA.BlockStatus.lookup(ParentBlock); + if (State.hasValue()) + IsHot = State; + + return CheckChange(); + } +}; } // namespace AACallGraphNode *AACallEdgeIterator::operator*() const { @@ -9747,6 +9850,7 @@ CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoReturn) CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAReturnedValues) CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryLocation) +CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAHotCold) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonNull) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoAlias) @@ -9768,7 +9872,6 @@ CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUndefinedBehavior) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AACallEdges) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAFunctionReachability) -CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAHotCold) CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryBehavior) diff --git a/llvm/test/Transforms/Attributor/hotcold.ll b/llvm/test/Transforms/Attributor/hotcold.ll --- a/llvm/test/Transforms/Attributor/hotcold.ll +++ b/llvm/test/Transforms/Attributor/hotcold.ll @@ -60,10 +60,30 @@ ret void } +; This function gets marked as cold because its only call site is in the +; block of @external that also calls other cold functions.
+define internal void @test_cold_block() { +; IS__TUNIT____: Function Attrs: cold nofree nosync nounwind willreturn writeonly +; IS__TUNIT____-LABEL: define {{[^@]+}}@test_cold_block +; IS__TUNIT____-SAME: () #[[ATTR0]] { +; IS__TUNIT____-NEXT: store i32 1, i32* @G, align 4 +; IS__TUNIT____-NEXT: ret void +; +; IS__CGSCC____: Function Attrs: cold nofree norecurse nosync nounwind willreturn writeonly +; IS__CGSCC____-LABEL: define {{[^@]+}}@test_cold_block +; IS__CGSCC____-SAME: () #[[ATTR0]] { +; IS__CGSCC____-NEXT: store i32 1, i32* @G, align 4 +; IS__CGSCC____-NEXT: ret void +; + store i32 1, i32* @G, align 4 + ret void +} + define void @external() { ; IS__TUNIT____: Function Attrs: nofree nosync nounwind willreturn writeonly ; IS__TUNIT____-LABEL: define {{[^@]+}}@external ; IS__TUNIT____-SAME: () #[[ATTR1]] { +; IS__TUNIT____-NEXT: call void @test_cold_block() #[[ATTR1]] ; IS__TUNIT____-NEXT: call void @cold_caller2() #[[ATTR1]] ; IS__TUNIT____-NEXT: call void @cold_caller1() #[[ATTR1]] ; IS__TUNIT____-NEXT: ret void @@ -71,10 +91,12 @@ ; IS__CGSCC____: Function Attrs: nofree norecurse nosync nounwind willreturn writeonly ; IS__CGSCC____-LABEL: define {{[^@]+}}@external ; IS__CGSCC____-SAME: () #[[ATTR1:[0-9]+]] { +; IS__CGSCC____-NEXT: call void @test_cold_block() #[[ATTR2]] ; IS__CGSCC____-NEXT: call void @cold_caller2() #[[ATTR2]] ; IS__CGSCC____-NEXT: call void @cold_caller1() #[[ATTR2]] ; IS__CGSCC____-NEXT: ret void ; + call void @test_cold_block() call void @cold_caller2() call void @cold_caller1() ret void