Add a TargetTransformInfo hook, getCallerAllocaCost(CB, AI), and use it to
seed the inline cost model's per-alloca SROA accounting.

  * TargetTransformInfo.h/.cpp: declare the hook, add it to the Concept/Model
    forwarding layer, and forward the public entry point to TTIImpl.
  * TargetTransformInfoImpl.h, BasicTTIImpl.h: default implementations
    return 0.
  * InlineCost.cpp: both onInitializeSROAArg overrides now start the
    per-alloca cost at the hook's value instead of 0 and add the same value
    to SROACostSavings / SROACostSavingOpportunities respectively.
  * InlineCostTest.cpp: factor the call lookup and analysis-manager setup
    into getCallInFunction / getInliningCostFeaturesForCall, and add an
    SROACost test that checks the sroa_savings and sroa_losses features.

NOTE(review): the template arguments in InlineCostTest.cpp
(dyn_cast<CallBase>, std::optional<InlineCostFeaturesArray>,
getResult<AssumptionAnalysis>, getResult<TargetIRAnalysis>,
std::unique_ptr<Module>, static_cast<size_t>) were restored from the
surrounding context; confirm them (including the removed '-' lines) against
the tree before applying.

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -39,6 +39,7 @@
 typedef unsigned ID;
 }
 
+class AllocaInst;
 class AssumptionCache;
 class BlockFrequencyInfo;
 class DominatorTree;
@@ -340,6 +341,10 @@
   /// \returns A value to be added to the inlining threshold.
   unsigned adjustInliningThreshold(const CallBase *CB) const;
 
+  /// \returns The cost of having an Alloca in the caller if not inlined, to be
+  /// added to the threshold
+  unsigned getCallerAllocaCost(const CallBase *CB, const AllocaInst *AI) const;
+
   /// \returns Vector bonus in percent.
   ///
   /// Vector bonuses: We want to more aggressively inline vector-dense kernels
@@ -1683,6 +1688,8 @@
   virtual unsigned getInliningThresholdMultiplier() const = 0;
   virtual unsigned adjustInliningThreshold(const CallBase *CB) = 0;
   virtual int getInlinerVectorBonusPercent() const = 0;
+  virtual unsigned getCallerAllocaCost(const CallBase *CB,
+                                       const AllocaInst *AI) const = 0;
   virtual InstructionCost getMemcpyCost(const Instruction *I) = 0;
   virtual uint64_t getMaxMemIntrinsicInlineSizeThreshold() const = 0;
   virtual unsigned
@@ -2054,6 +2061,10 @@
   int getInlinerVectorBonusPercent() const override {
     return Impl.getInlinerVectorBonusPercent();
   }
+  unsigned getCallerAllocaCost(const CallBase *CB,
+                               const AllocaInst *AI) const override {
+    return Impl.getCallerAllocaCost(CB, AI);
+  }
   InstructionCost getMemcpyCost(const Instruction *I) override {
     return Impl.getMemcpyCost(I);
   }
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -70,6 +70,9 @@
 
   unsigned getInliningThresholdMultiplier() const { return 1; }
   unsigned adjustInliningThreshold(const CallBase *CB) const { return 0; }
+  unsigned getCallerAllocaCost(const CallBase *CB, const AllocaInst *AI) const {
+    return 0;
+  }
 
   int getInlinerVectorBonusPercent() const { return 150; }
 
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -538,6 +538,9 @@
 
   unsigned getInliningThresholdMultiplier() const { return 1; }
   unsigned adjustInliningThreshold(const CallBase *CB) { return 0; }
+  unsigned getCallerAllocaCost(const CallBase *CB, const AllocaInst *AI) const {
+    return 0;
+  }
 
   int getInlinerVectorBonusPercent() const { return 150; }
 
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -717,7 +717,9 @@
   void onInitializeSROAArg(AllocaInst *Arg) override {
     assert(Arg != nullptr &&
            "Should not initialize SROA costs for null value.");
-    SROAArgCosts[Arg] = 0;
+    auto SROAArgCost = TTI.getCallerAllocaCost(&CandidateCall, Arg);
+    SROACostSavings += SROAArgCost;
+    SROAArgCosts[Arg] = SROAArgCost;
   }
 
   void onAggregateSROAUse(AllocaInst *SROAArg) override {
@@ -1191,7 +1193,12 @@
               InstrCost);
   }
 
-  void onInitializeSROAArg(AllocaInst *Arg) override { SROACosts[Arg] = 0; }
+  void onInitializeSROAArg(AllocaInst *Arg) override {
+    auto SROAArgCost = TTI.getCallerAllocaCost(&CandidateCall, Arg);
+    SROACosts[Arg] = SROAArgCost;
+    SROACostSavingOpportunities += SROAArgCost;
+  }
+
   void onAggregateSROAUse(AllocaInst *Arg) override {
     SROACosts.find(Arg)->second += InstrCost;
     SROACostSavingOpportunities += InstrCost;
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -217,6 +217,11 @@
   return TTIImpl->adjustInliningThreshold(CB);
 }
 
+unsigned TargetTransformInfo::getCallerAllocaCost(const CallBase *CB,
+                                                  const AllocaInst *AI) const {
+  return TTIImpl->getCallerAllocaCost(CB, AI);
+}
+
 int TargetTransformInfo::getInlinerVectorBonusPercent() const {
   return TTIImpl->getInlinerVectorBonusPercent();
 }
diff --git a/llvm/unittests/Analysis/InlineCostTest.cpp b/llvm/unittests/Analysis/InlineCostTest.cpp
--- a/llvm/unittests/Analysis/InlineCostTest.cpp
+++ b/llvm/unittests/Analysis/InlineCostTest.cpp
@@ -8,8 +8,10 @@
 
 #include "llvm/Analysis/InlineCost.h"
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/InlineModelFeatureMaps.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
@@ -18,10 +20,40 @@
 
 namespace {
 
+using namespace llvm;
+
+CallBase *getCallInFunction(Function *F) {
+  for (auto &I : instructions(F)) {
+    if (auto *CB = dyn_cast<CallBase>(&I))
+      return CB;
+  }
+  return nullptr;
+}
+
+std::optional<InlineCostFeaturesArray> getInliningCostFeaturesForCall(CallBase &CB) {
+  ModuleAnalysisManager MAM;
+  FunctionAnalysisManager FAM;
+  FAM.registerPass([&] { return TargetIRAnalysis(); });
+  FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
+  FAM.registerPass([&] { return AssumptionAnalysis(); });
+  MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
+
+  MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+  FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+
+  ModulePassManager MPM;
+  MPM.run(*CB.getModule(), MAM);
+
+  auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
+    return FAM.getResult<AssumptionAnalysis>(F);
+  };
+  auto &TIR = FAM.getResult<TargetIRAnalysis>(*CB.getFunction());
+
+  return getInliningCostFeatures(CB, TIR, GetAssumptionCache);
+}
+
 // Tests that we can retrieve the CostFeatures without an error
 TEST(InlineCostTest, CostFeatures) {
-  using namespace llvm;
-
   const auto *const IR = R"IR(
 define i32 @f(i32) {
   ret i32 4
@@ -42,38 +74,80 @@
   ASSERT_TRUE(G);
 
   // find the call to f in g
-  CallBase *CB = nullptr;
-  for (auto &BB : *G) {
-    for (auto &I : BB) {
-      if ((CB = dyn_cast<CallBase>(&I)))
-        break;
-    }
-  }
+  CallBase *CB = getCallInFunction(G);
   ASSERT_TRUE(CB);
 
-  ModuleAnalysisManager MAM;
-  FunctionAnalysisManager FAM;
-  FAM.registerPass([&] { return TargetIRAnalysis(); });
-  FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
-  FAM.registerPass([&] { return AssumptionAnalysis(); });
-  MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
+  const auto Features = getInliningCostFeaturesForCall(*CB);
 
-  MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
-  FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+  // Check that the optional is not empty
+  ASSERT_TRUE(Features);
+}
 
-  ModulePassManager MPM;
-  MPM.run(*M, MAM);
+// Tests the calculated SROA cost
+TEST(InlineCostTest, SROACost) {
+  using namespace llvm;
 
-  auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
-    return FAM.getResult<AssumptionAnalysis>(F);
-  };
-  auto &TIR = FAM.getResult<TargetIRAnalysis>(*G);
+  const auto *const IR = R"IR(
+define void @f_savings(ptr %var) {
+  %load = load i32, ptr %var
+  %inc = add i32 %load, 1
+  store i32 %inc, ptr %var
+  ret void
+}
+
+define void @g_savings(i32) {
+  %var = alloca i32
+  call void @f_savings(ptr %var)
+  ret void
+}
 
-  const auto Features =
-      llvm::getInliningCostFeatures(*CB, TIR, GetAssumptionCache);
+define void @f_losses(ptr %var) {
+  %load = load i32, ptr %var
+  %inc = add i32 %load, 1
+  store i32 %inc, ptr %var
+  call void @prevent_sroa(ptr %var)
+  ret void
+}
 
-  // Check that the optional is not empty
-  ASSERT_TRUE(Features);
+define void @g_losses(i32) {
+  %var = alloca i32
+  call void @f_losses(ptr %var)
+  ret void
+}
+
+declare void @prevent_sroa(ptr)
+)IR";
+
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(IR, Err, C);
+  ASSERT_TRUE(M);
+
+  const int DefaultInstCost = 5;
+  const int DefaultAllocaCost = 0;
+
+  const char *GName[] = {"g_savings", "g_losses", nullptr};
+  const int Savings[] = {2 * DefaultInstCost + DefaultAllocaCost, 0};
+  const int Losses[] = {0, 2 * DefaultInstCost + DefaultAllocaCost};
+
+  for (unsigned i = 0; GName[i]; ++i) {
+    auto *G = M->getFunction(GName[i]);
+    ASSERT_TRUE(G);
+
+    // find the call to f in g
+    CallBase *CB = getCallInFunction(G);
+    ASSERT_TRUE(CB);
+
+    const auto Features = getInliningCostFeaturesForCall(*CB);
+    ASSERT_TRUE(Features);
+
+    // Check the predicted SROA cost
+    auto GetFeature = [&](InlineCostFeatureIndex I) {
+      return (*Features)[static_cast<size_t>(I)];
+    };
+    ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_savings), Savings[i]);
+    ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_losses), Losses[i]);
+  }
 }
 
 } // namespace