[InlineCost][TTI] Let targets charge for caller-side allocas in SROA savings

Adds a TargetTransformInfo hook, getCallerAllocaCost(CB, AI), and threads it
through the TTI layers touched below (the public TargetTransformInfo class,
its Concept/Model pair, TargetTransformInfoImpl.h and BasicTTIImpl.h). Both
default implementations return 0.

In InlineCost.cpp, both onInitializeSROAArg overrides now seed the per-alloca
SROA cost with the value returned by the hook instead of 0, and add the same
amount to the running savings counter (SROACostSavings and
SROACostSavingOpportunities respectively).

The new InlineCostTest.SROACost unit test computes the inline-cost features
for the call to @f in @g, which passes an alloca as the argument, and expects
sroa_savings == 10 and sroa_losses == 0.

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -39,6 +39,7 @@
 typedef unsigned ID;
 }
 
+class AllocaInst;
 class AssumptionCache;
 class BlockFrequencyInfo;
 class DominatorTree;
@@ -340,6 +341,10 @@
   /// \returns A value to be added to the inlining threshold.
   unsigned adjustInliningThreshold(const CallBase *CB) const;
 
+  /// \returns The cost of having an Alloca in the caller if not inlined, to be
+  /// added to the threshold.
+  unsigned getCallerAllocaCost(const CallBase *CB, const AllocaInst *AI) const;
+
   /// \returns Vector bonus in percent.
   ///
   /// Vector bonuses: We want to more aggressively inline vector-dense kernels
@@ -1672,6 +1677,8 @@
   virtual unsigned getInliningThresholdMultiplier() const = 0;
   virtual unsigned adjustInliningThreshold(const CallBase *CB) = 0;
   virtual int getInlinerVectorBonusPercent() const = 0;
+  virtual unsigned getCallerAllocaCost(const CallBase *CB,
+                                       const AllocaInst *AI) const = 0;
   virtual InstructionCost getMemcpyCost(const Instruction *I) = 0;
   virtual unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
                                                     unsigned &JTSize,
@@ -2041,6 +2048,10 @@
   int getInlinerVectorBonusPercent() const override {
     return Impl.getInlinerVectorBonusPercent();
   }
+  unsigned getCallerAllocaCost(const CallBase *CB,
+                               const AllocaInst *AI) const override {
+    return Impl.getCallerAllocaCost(CB, AI);
+  }
   InstructionCost getMemcpyCost(const Instruction *I) override {
     return Impl.getMemcpyCost(I);
   }
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -70,6 +70,9 @@
 
   unsigned getInliningThresholdMultiplier() const { return 1; }
   unsigned adjustInliningThreshold(const CallBase *CB) const { return 0; }
+  unsigned getCallerAllocaCost(const CallBase *CB, const AllocaInst *AI) const {
+    return 0;
+  }
 
   int getInlinerVectorBonusPercent() const { return 150; }
 
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -534,6 +534,9 @@
 
   unsigned getInliningThresholdMultiplier() const { return 1; }
   unsigned adjustInliningThreshold(const CallBase *CB) { return 0; }
+  unsigned getCallerAllocaCost(const CallBase *CB, const AllocaInst *AI) const {
+    return 0;
+  }
 
   int getInlinerVectorBonusPercent() const { return 150; }
 
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -717,7 +717,8 @@
   void onInitializeSROAArg(AllocaInst *Arg) override {
     assert(Arg != nullptr &&
            "Should not initialize SROA costs for null value.");
-    SROAArgCosts[Arg] = 0;
+    SROACostSavings += SROAArgCosts[Arg] =
+        TTI.getCallerAllocaCost(&CandidateCall, Arg);
   }
 
   void onAggregateSROAUse(AllocaInst *SROAArg) override {
@@ -1191,7 +1192,11 @@
               InstrCost);
   }
 
-  void onInitializeSROAArg(AllocaInst *Arg) override { SROACosts[Arg] = 0; }
+  void onInitializeSROAArg(AllocaInst *Arg) override {
+    SROACostSavingOpportunities += SROACosts[Arg] =
+        TTI.getCallerAllocaCost(&CandidateCall, Arg);
+  }
+
   void onAggregateSROAUse(AllocaInst *Arg) override {
     SROACosts.find(Arg)->second += InstrCost;
     SROACostSavingOpportunities += InstrCost;
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -217,6 +217,11 @@
   return TTIImpl->adjustInliningThreshold(CB);
 }
 
+unsigned TargetTransformInfo::getCallerAllocaCost(const CallBase *CB,
+                                                  const AllocaInst *AI) const {
+  return TTIImpl->getCallerAllocaCost(CB, AI);
+}
+
 int TargetTransformInfo::getInlinerVectorBonusPercent() const {
   return TTIImpl->getInlinerVectorBonusPercent();
 }
diff --git a/llvm/unittests/Analysis/InlineCostTest.cpp b/llvm/unittests/Analysis/InlineCostTest.cpp
--- a/llvm/unittests/Analysis/InlineCostTest.cpp
+++ b/llvm/unittests/Analysis/InlineCostTest.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/Analysis/InlineCost.h"
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/InlineModelFeatureMaps.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Instructions.h"
@@ -76,4 +77,74 @@
   ASSERT_TRUE(Features);
 }
 
+// Tests the calculated SROA cost
+TEST(InlineCostTest, SROACost) {
+  using namespace llvm;
+
+  const auto *const IR = R"IR(
+define void @f(ptr %var) {
+  %load = load i32, ptr %var
+  %inc = add i32 %load, 1
+  store i32 %inc, ptr %var
+  ret void
+}
+
+define void @g(i32) {
+  %var = alloca i32
+  call void @f(ptr %var)
+  ret void
+}
+)IR";
+
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(IR, Err, C);
+  ASSERT_TRUE(M);
+
+  auto *G = M->getFunction("g");
+  ASSERT_TRUE(G);
+
+  // Find the first call in g (the call to f); later blocks never overwrite it.
+  CallBase *CB = nullptr;
+  for (auto &BB : *G) {
+    for (auto &I : BB) {
+      if (!CB)
+        CB = dyn_cast<CallBase>(&I);
+    }
+  }
+  ASSERT_TRUE(CB);
+
+  ModuleAnalysisManager MAM;
+  FunctionAnalysisManager FAM;
+  FAM.registerPass([&] { return TargetIRAnalysis(); });
+  FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
+  FAM.registerPass([&] { return AssumptionAnalysis(); });
+  MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
+
+  MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+  FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+
+  ModulePassManager MPM;
+  MPM.run(*M, MAM);
+
+  auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
+    return FAM.getResult<AssumptionAnalysis>(F);
+  };
+  auto &TIR = FAM.getResult<TargetIRAnalysis>(*G);
+
+  const auto Features =
+      llvm::getInliningCostFeatures(*CB, TIR, GetAssumptionCache);
+
+  ASSERT_TRUE(Features);
+
+  // Check the predicted SROA cost
+  auto GetFeature = [&](InlineCostFeatureIndex I) {
+    return (*Features)[static_cast<size_t>(I)];
+  };
+  ASSERT_EQ(
+      GetFeature(InlineCostFeatureIndex::sroa_savings),
+      10); // 2*InstCost which defaults to 5 + AllocaCost which defaults to 0
+  ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_losses), 0);
+}
+
 } // namespace