diff --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
--- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
+++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
@@ -69,6 +69,9 @@
 // Just a shorter abbreviation to improve indentation.
 using Cost = InstructionCost;
 
+// Map of known constants found during the specialization bonus estimation.
+using ConstMap = DenseMap<Value *, Constant *>;
+
 // Specialization signature, used to uniquely designate a specialization within
 // a function.
 struct SpecSig {
@@ -151,6 +154,12 @@
 
   bool run();
 
+  /// Compute a bonus for replacing argument \p A with constant \p C.
+  /// \p KnownConstants records the values found to be constant; it may be
+  /// shared between the arguments of the same specialization.
+  Cost getSpecializationBonus(Argument *A, Constant *C,
+                              ConstMap &KnownConstants);
+
 private:
   Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call);
 
@@ -194,8 +203,21 @@
   /// Compute and return the cost of specializing function \p F.
   Cost getSpecializationCost(Function *F);
 
-  /// Compute a bonus for replacing argument \p A with constant \p C.
-  Cost getSpecializationBonus(Argument *A, Constant *C);
+  /// Estimate the bonus of the basic blocks that become dead when the
+  /// condition \p V of switch \p I is known to be the constant \p C.
+  Cost estimateSwitchInst(SwitchInst *I, Value *V, ConstantInt *C,
+                          ConstMap &KnownConstants);
+
+  /// Estimate the bonus of the basic blocks that become dead when the
+  /// condition \p V of branch \p I is known to be the constant \p C.
+  Cost estimateBranchInst(BranchInst *I, Value *V, Constant *C,
+                          ConstMap &KnownConstants);
+
+  /// Compute the bonus of \p User given that its operand \p Use is the
+  /// constant \p C. When \p User folds to a constant as well, it is recorded
+  /// in \p KnownConstants and the bonus of its own users is added.
+  Cost getUserBonus(Instruction *User, Value *Use, Constant *C,
+                    ConstMap &KnownConstants);
 
   /// Determine if it is possible to specialise the function for constant values
   /// of the formal parameter \p A.
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -48,7 +48,9 @@
 #include "llvm/Transforms/IPO/FunctionSpecialization.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InlineCost.h"
+#include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueLattice.h"
 #include "llvm/Analysis/ValueLatticeUtils.h"
@@ -412,10 +414,6 @@
     CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues);
     for (BasicBlock &BB : *F)
       Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues);
-
-    LLVM_DEBUG(dbgs() << "FnSpecialization: Code size of function "
-                      << F->getName() << " is " << Metrics.NumInsts
-                      << " instructions\n");
   }
   return Metrics;
 }
@@ -496,8 +494,9 @@
     } else {
       // Calculate the specialisation gain.
       Cost Score = 0 - SpecCost;
+      ConstMap KnownConstants;
       for (ArgInfo &A : S.Args)
-        Score += getSpecializationBonus(A.Formal, A.Actual);
+        Score += getSpecializationBonus(A.Formal, A.Actual, KnownConstants);
 
       // Discard unprofitable specialisations.
       if (!ForceSpecialization && Score <= 0)
@@ -584,49 +583,258 @@
 
   // Otherwise, set the specialization cost to be the cost of all the
   // instructions in the function.
-  return Metrics.NumInsts * InlineConstants::getInstrCost();
+  return Metrics.NumInsts;
 }
 
-static Cost getUserBonus(User *U, TargetTransformInfo &TTI,
-                         BlockFrequencyInfo &BFI) {
-  auto *I = dyn_cast_or_null<Instruction>(U);
-  // If not an instruction we do not know how to evaluate.
-  // Keep minimum possible cost for now so that it doesnt affect
-  // specialization.
-  if (!I)
-    return 0;
-
-  uint64_t Weight = BFI.getBlockFreq(I->getParent()).getFrequency() /
-                    BFI.getEntryFreq();
-  if (!Weight)
-    return 0;
-
-  Cost Bonus = Weight *
-      TTI.getInstructionCost(U, TargetTransformInfo::TCK_SizeAndLatency);
-
-  // Traverse recursively if there are more uses.
-  // TODO: Any other instructions to be added here?
-  if (I->mayReadFromMemory() || I->isCast())
-    for (auto *User : I->users())
-      Bonus += getUserBonus(User, TTI, BFI);
+/// Return the constant that \p V is known to be: \p V itself when it is a
+/// Constant, otherwise the value recorded for it in \p KnownConstants, or
+/// nullptr when neither applies.
+static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) {
+  if (auto *C = dyn_cast<Constant>(V))
+    return C;
+  if (auto It = KnownConstants.find(V); It != KnownConstants.end())
+    return It->second;
+  return nullptr;
+}
+
+/// Try to fold \p User to a constant, given that its operand \p Use is the
+/// constant \p C and that the values recorded in \p KnownConstants are
+/// constants too. Returns nullptr when \p User cannot be folded.
+static Constant *foldUserToConstant(Instruction *User, Value *Use,
+                                    Constant *C, ConstMap &KnownConstants,
+                                    const DataLayout &DL) {
+  if (auto *I = dyn_cast<LoadInst>(User)) {
+    // Loads from a null pointer are not folded.
+    if (isa<ConstantPointerNull>(C))
+      return nullptr;
+    return ConstantFoldLoadFromConstPtr(C, I->getType(), DL);
+  }
+
+  if (auto *I = dyn_cast<GetElementPtrInst>(User)) {
+    // Every operand, including the base pointer, must be a known constant.
+    SmallVector<Constant *, 8> Operands;
+    Operands.reserve(I->getNumOperands());
+    for (Value *Op : I->operands()) {
+      Constant *OpC = findConstantFor(Op, KnownConstants);
+      if (!OpC)
+        return nullptr;
+      Operands.push_back(OpC);
+    }
+    return ConstantExpr::getGetElementPtr(
+        I->getSourceElementType(), Operands[0],
+        ArrayRef<Constant *>(Operands).drop_front());
+  }
+
+  if (auto *I = dyn_cast<SelectInst>(User)) {
+    // Only a known condition tells which operand gets selected. A condition
+    // that is neither all zeros nor all ones (e.g. undef, a constant
+    // expression or a vector with mixed lanes) selects no single operand.
+    if (I->getCondition() != Use)
+      return nullptr;
+    if (C->isZeroValue())
+      return findConstantFor(I->getFalseValue(), KnownConstants);
+    if (C->isOneValue())
+      return findConstantFor(I->getTrueValue(), KnownConstants);
+    return nullptr;
+  }
+
+  if (auto *I = dyn_cast<CastInst>(User))
+    return ConstantFoldCastOperand(I->getOpcode(), C, I->getType(), DL);
+
+  if (auto *I = dyn_cast<CmpInst>(User)) {
+    bool Swap = I->getOperand(1) == Use;
+    Constant *Other =
+        findConstantFor(I->getOperand(Swap ? 0 : 1), KnownConstants);
+    if (!Other)
+      return nullptr;
+    return Swap ? ConstantFoldCompareInstOperands(I->getPredicate(), Other,
+                                                  C, DL)
+                : ConstantFoldCompareInstOperands(I->getPredicate(), C,
+                                                  Other, DL);
+  }
+
+  if (Instruction::isUnaryOp(User->getOpcode()))
+    return ConstantFoldUnaryOpOperand(User->getOpcode(), C, DL);
+
+  if (Instruction::isBinaryOp(User->getOpcode())) {
+    bool Swap = User->getOperand(1) == Use;
+    Constant *Other =
+        findConstantFor(User->getOperand(Swap ? 0 : 1), KnownConstants);
+    if (!Other)
+      return nullptr;
+    SimplifyQuery SQ(DL);
+    Value *V = Swap ? simplifyBinOp(User->getOpcode(), Other, C, SQ)
+                    : simplifyBinOp(User->getOpcode(), C, Other, SQ);
+    return dyn_cast_or_null<Constant>(V);
+  }
+
+  // Any other instruction (e.g. a PHI node or a call) is not folded.
+  return nullptr;
+}
+
+/// Accumulate the cost of the basic blocks in \p WorkList, which are assumed
+/// to become dead, together with their executable successors that have a
+/// single predecessor. Instruction costs are scaled by the block frequency
+/// relative to the function entry; ssa_copy intrinsics and instructions that
+/// are already recorded in \p KnownConstants are not counted.
+static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
+                                ConstMap &KnownConstants, SCCPSolver &Solver,
+                                BlockFrequencyInfo &BFI,
+                                TargetTransformInfo &TTI) {
+  Cost Bonus = 0;
+  while (!WorkList.empty()) {
+    BasicBlock *BB = WorkList.pop_back_val();
+
+    // Blocks that run less often than the entry get a zero weight, and their
+    // successors are not explored.
+    uint64_t Weight =
+        BFI.getBlockFreq(BB).getFrequency() / BFI.getEntryFreq();
+    if (!Weight)
+      continue;
+
+    for (Instruction &Inst : *BB) {
+      if (auto *II = dyn_cast<IntrinsicInst>(&Inst))
+        if (II->getIntrinsicID() == Intrinsic::ssa_copy)
+          continue;
+      // Already folded to a constant, so it is accounted for elsewhere.
+      if (KnownConstants.contains(&Inst))
+        continue;
+
+      Bonus += Weight * TTI.getInstructionCost(
+                            &Inst, TargetTransformInfo::TCK_SizeAndLatency);
+
+      LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus
+                        << " after user " << Inst << "\n");
+    }
+
+    // Populate the worklist with dead successors.
+    for (BasicBlock *SuccBB : successors(BB))
+      if (SuccBB->hasNPredecessors(1) && Solver.isBlockExecutable(SuccBB))
+        WorkList.push_back(SuccBB);
+  }
 
   return Bonus;
 }
 
+Cost FunctionSpecializer::estimateSwitchInst(SwitchInst *I, Value *V,
+                                             ConstantInt *C,
+                                             ConstMap &KnownConstants) {
+  // Only the switch condition decides which destinations become dead.
+  if (I->getCondition() != V)
+    return 0;
+
+  BasicBlock *Succ = I->findCaseValue(C)->getCaseSuccessor();
+
+  // Initialize the worklist with the case destinations not taken for C,
+  // provided they are executable and have a single predecessor.
+  SmallVector<BasicBlock *> WorkList;
+  for (const auto &Case : I->cases()) {
+    BasicBlock *BB = Case.getCaseSuccessor();
+    if (BB == Succ || !BB->hasNPredecessors(1) ||
+        !Solver.isBlockExecutable(BB))
+      continue;
+    WorkList.push_back(BB);
+  }
+
+  Function *F = I->getFunction();
+  return estimateBasicBlocks(WorkList, KnownConstants, Solver, (GetBFI)(*F),
+                             (GetTTI)(*F));
+}
+
+Cost FunctionSpecializer::estimateBranchInst(BranchInst *I, Value *V,
+                                             Constant *C,
+                                             ConstMap &KnownConstants) {
+  if (!I->isConditional() || I->getCondition() != V)
+    return 0;
+
+  // Only an integer constant tells which way the branch goes. Anything else
+  // (e.g. undef or a constant expression) must not be treated as false.
+  auto *CI = dyn_cast<ConstantInt>(C);
+  if (!CI)
+    return 0;
+
+  // Successor 0 is taken on true and successor 1 on false, so the dead one
+  // is at index 1 when the condition is true and at index 0 otherwise.
+  BasicBlock *Succ = I->getSuccessor(CI->isOneValue());
+
+  // Initialize the worklist with the dead successor, provided it is
+  // executable and has a single predecessor.
+  SmallVector<BasicBlock *> WorkList;
+  if (Succ->hasNPredecessors(1) && Solver.isBlockExecutable(Succ))
+    WorkList.push_back(Succ);
+
+  Function *F = I->getFunction();
+  return estimateBasicBlocks(WorkList, KnownConstants, Solver, (GetBFI)(*F),
+                             (GetTTI)(*F));
+}
+
+Cost FunctionSpecializer::getUserBonus(Instruction *User, Value *Use,
+                                       Constant *C, ConstMap &KnownConstants) {
+  KnownConstants.insert({Use, C});
+
+  // A user reached through several operands (or from several arguments) has
+  // already been folded and accounted for; counting it again would inflate
+  // the bonus.
+  if (KnownConstants.contains(User))
+    return 0;
+
+  if (auto *I = dyn_cast<SwitchInst>(User)) {
+    // Only an integer constant selects a switch case.
+    auto *CI = dyn_cast<ConstantInt>(C);
+    if (!CI)
+      return 0;
+    return estimateSwitchInst(I, Use, CI, KnownConstants);
+  }
+
+  if (auto *I = dyn_cast<BranchInst>(User))
+    return estimateBranchInst(I, Use, C, KnownConstants);
+
+  C = foldUserToConstant(User, Use, C, KnownConstants, M.getDataLayout());
+  if (!C)
+    return 0;
+
+  KnownConstants.insert({User, C});
+
+  Function *F = User->getFunction();
+  auto &TTI = (GetTTI)(*F);
+  auto &BFI = (GetBFI)(*F);
+
+  // Scale the cost by the block frequency relative to the function entry.
+  uint64_t Weight = BFI.getBlockFreq(User->getParent()).getFrequency() /
+                    BFI.getEntryFreq();
+  if (!Weight)
+    return 0;
+
+  Cost Bonus = Weight *
+      TTI.getInstructionCost(User, TargetTransformInfo::TCK_SizeAndLatency);
+
+  LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus
+                    << " for user " << *User << "\n");
+
+  // Propagate the folded constant to the executable users of this
+  // instruction.
+  for (auto *U : User->users())
+    if (auto *UI = dyn_cast<Instruction>(U))
+      if (Solver.isBlockExecutable(UI->getParent()))
+        Bonus += getUserBonus(UI, User, C, KnownConstants);
+
+  return Bonus;
+}
+
 /// Compute a bonus for replacing argument \p A with constant \p C.
-Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C) {
-  Function *F = A->getParent();
-  auto &TTI = (GetTTI)(*F);
-  auto &BFI = (GetBFI)(*F);
+Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C,
+                                                 ConstMap &KnownConstants) {
   LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: "
                     << C->getNameOrAsOperand() << "\n");
 
   Cost TotalCost = 0;
-  for (auto *U : A->users()) {
-    TotalCost += getUserBonus(U, TTI, BFI);
-    LLVM_DEBUG(dbgs() << "FnSpecialization: User cost ";
-               TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n");
-  }
+  for (auto *U : A->users())
+    if (auto *UI = dyn_cast<Instruction>(U))
+      if (Solver.isBlockExecutable(UI->getParent()))
+        TotalCost += getUserBonus(UI, A, C, KnownConstants);
+
+  LLVM_DEBUG(dbgs() << "FnSpecialization: Accumulated user bonus "
+                    << TotalCost << " for argument " << *A << "\n");
 
   // The below heuristic is only concerned with exposing inlining
   // opportunities via indirect call promotion. If the argument is not a
diff --git a/llvm/unittests/Transforms/IPO/CMakeLists.txt b/llvm/unittests/Transforms/IPO/CMakeLists.txt
--- a/llvm/unittests/Transforms/IPO/CMakeLists.txt
+++ b/llvm/unittests/Transforms/IPO/CMakeLists.txt
@@ -12,6 +12,7 @@
   LowerTypeTests.cpp
   WholeProgramDevirt.cpp
   AttributorTest.cpp
+  FunctionSpecializationTest.cpp
   )
 
 set_property(TARGET IPOTests PROPERTY FOLDER "Tests/UnitTests/TransformsTests")
diff --git a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
@@ -0,0 +1,275 @@
+//===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/IPO/FunctionSpecialization.h"
+#include "llvm/Transforms/Utils/SCCPSolver.h"
+#include "gtest/gtest.h"
+#include <memory>
+
+namespace llvm {
+
+class FunctionSpecializationTest : public testing::Test {
+protected:
+  std::unique_ptr<LLVMContext> Ctx;
+  std::unique_ptr<Module> M;
+  std::unique_ptr<ModuleAnalysisManager> MAM;
+  std::unique_ptr<FunctionAnalysisManager> FAM;
+  std::unique_ptr<SCCPSolver> Solver;
+  std::unique_ptr<FunctionSpecializer> Specializer;
+
+  FunctionSpecializationTest() {
+    Ctx = std::make_unique<LLVMContext>();
+    MAM = std::make_unique<ModuleAnalysisManager>();
+    FAM = std::make_unique<FunctionAnalysisManager>();
+
+    FAM->registerPass([&] { return TargetLibraryAnalysis(); });
+    FAM->registerPass([&] { return TargetIRAnalysis(); });
+    FAM->registerPass([&] { return BlockFrequencyAnalysis(); });
+    FAM->registerPass([&] { return BranchProbabilityAnalysis(); });
+    FAM->registerPass([&] { return LoopAnalysis(); });
+    FAM->registerPass([&] { return AssumptionAnalysis(); });
+    FAM->registerPass([&] { return DominatorTreeAnalysis(); });
+    FAM->registerPass([&] { return PostDominatorTreeAnalysis(); });
+    FAM->registerPass([&] { return ModuleAnalysisManagerFunctionProxy(*MAM); });
+    FAM->registerPass([&] { return PassInstrumentationAnalysis(); });
+    MAM->registerPass([&] { return FunctionAnalysisManagerModuleProxy(*FAM); });
+    MAM->registerPass([&] { return PassInstrumentationAnalysis(); });
+  }
+
+  Module &parseModule(const char *ModuleString) {
+    SMDiagnostic Err;
+    M = parseAssemblyString(ModuleString, Err, *Ctx);
+    EXPECT_TRUE(M);
+    return *M;
+  }
+
+  FunctionSpecializer &getSpecializerFor(Function *F) {
+    ModulePassManager MPM;
+    MPM.run(*M, *MAM);
+
+    auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
+      return FAM->getResult<TargetLibraryAnalysis>(F);
+    };
+    auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
+      return FAM->getResult<TargetIRAnalysis>(F);
+    };
+    auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
+      return FAM->getResult<BlockFrequencyAnalysis>(F);
+    };
+    auto GetAC = [this](Function &F) -> AssumptionCache & {
+      return FAM->getResult<AssumptionAnalysis>(F);
+    };
+    auto GetAnalysis = [this](Function &F) -> AnalysisResultsForFn {
+      DominatorTree &DT = FAM->getResult<DominatorTreeAnalysis>(F);
+      return { std::make_unique<PredicateInfo>(F, DT,
+                   FAM->getResult<AssumptionAnalysis>(F)),
+               &DT, FAM->getCachedResult<PostDominatorTreeAnalysis>(F) };
+    };
+
+    Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, *Ctx);
+    Specializer = std::make_unique<FunctionSpecializer>(
+        *Solver, *M, &*FAM, GetBFI, GetTLI, GetTTI, GetAC);
+
+    Solver->addAnalysis(*F, GetAnalysis(*F));
+    Solver->markBlockExecutable(&F->front());
+    for (Argument &Arg : F->args())
+      Solver->markOverdefined(&Arg);
+    Solver->solveWhileResolvedUndefsIn(*M);
+
+    return *Specializer;
+  }
+};
+
+} // namespace llvm
+
+using namespace llvm;
+
+TEST_F(FunctionSpecializationTest, SwitchInst) {
+  const char *ModuleString = R"(
+    define void @foo(i32 %a, i32 %b, i32 %i) {
+    entry:
+      switch i32 %i, label %default
+      [ i32 1, label %case1
+        i32 2, label %case2 ]
+    case1:
+      %0 = mul i32 %a, 2
+      %1 = sub i32 6, 5
+      br label %bb1
+    case2:
+      %2 = and i32 %b, 3
+      %3 = sdiv i32 8, 2
+      br label %bb2
+    bb1:
+      %4 = add i32 %0, %b
+      br label %default
+    bb2:
+      %5 = or i32 %2, %a
+      br label %default
+    default:
+      ret void
+    }
+  )";
+
+  Module &M = parseModule(ModuleString);
+  Function *F = M.getFunction("foo");
+  FunctionSpecializer &Specializer = getSpecializerFor(F);
+
+  DenseMap<Value *, Constant *> KnownConstants;
+  Cost Bonus = 0;
+  Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
+
+  Bonus += Specializer.getSpecializationBonus(
+      F->getArg(0), One, KnownConstants);
+  EXPECT_EQ(Bonus, 0);
+
+  Bonus += Specializer.getSpecializationBonus(
+      F->getArg(1), One, KnownConstants);
+  EXPECT_EQ(Bonus, 0);
+
+  Bonus += Specializer.getSpecializationBonus(
+      F->getArg(2), One, KnownConstants);
+  EXPECT_EQ(Bonus, 0);
+}
+
+TEST_F(FunctionSpecializationTest, BranchInst) {
+  const char *ModuleString = R"(
+    define void @foo(i32 %a, i32 %b, i1 %cond) {
+    entry:
+      br i1 %cond, label %bb0, label %bb2
+    bb0:
+      %0 = mul i32 %a, 2
+      %1 = sub i32 6, 5
+      br label %bb1
+    bb1:
+      %2 = add i32 %0, %b
+      %3 = sdiv i32 8, 2
+      br label %bb2
+    bb2:
+      ret void
+    }
+  )";
+
+  Module &M = parseModule(ModuleString);
+  Function *F = M.getFunction("foo");
+  FunctionSpecializer &Specializer = getSpecializerFor(F);
+
+  DenseMap<Value *, Constant *> KnownConstants;
+  Cost Bonus = 0;
+  Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
+  Constant *False = ConstantInt::getFalse(M.getContext());
+
+  Bonus += Specializer.getSpecializationBonus(
+      F->getArg(0), One, KnownConstants);
+  EXPECT_EQ(Bonus, 0);
+
+  Bonus += Specializer.getSpecializationBonus(
+      F->getArg(1), One, KnownConstants);
+  EXPECT_EQ(Bonus, 0);
+
+  Bonus += Specializer.getSpecializationBonus(
+      F->getArg(2), False, KnownConstants);
+  EXPECT_EQ(Bonus, 0);
+}
+
+TEST_F(FunctionSpecializationTest, Misc) {
+  const char *ModuleString = R"(
+    @g = constant [2 x i32] zeroinitializer, align 4
+
+    define i32 @foo(ptr %a, i8 %b, i1 %cond) {
+      %sel = select i1 %cond, i8 %b, i8 10
+      %cmp = icmp eq i8 %sel, 10
+      %ext = zext i1 %cmp to i32
+      %gep = getelementptr i32, ptr %a, i32 %ext
+      %ld = load i32, ptr %gep
+      ret i32 %ld
+    }
+  )";
+
+  Module &M = parseModule(ModuleString);
+  Function *F = M.getFunction("foo");
+  FunctionSpecializer &Specializer = getSpecializerFor(F);
+
+  DenseMap<Value *, Constant *> KnownConstants;
+  Cost Bonus = 0;
+  GlobalVariable *GV = M.getGlobalVariable("g");
+  Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
+  Constant *True = ConstantInt::getTrue(M.getContext());
+
+  Bonus += Specializer.getSpecializationBonus(
+      F->getArg(0), GV, KnownConstants);
+  EXPECT_EQ(Bonus, 0);
+
+  Bonus += Specializer.getSpecializationBonus(
+      F->getArg(1), One, KnownConstants);
+  EXPECT_EQ(Bonus, 0);
+
+  Bonus += Specializer.getSpecializationBonus(
+      F->getArg(2), True, KnownConstants);
+  EXPECT_EQ(Bonus, 5); // select + icmp + zext + gep + load
+}
+
+TEST_F(FunctionSpecializationTest, RepeatedUse) {
+  const char *ModuleString = R"(
+    define i32 @foo(i32 %a) {
+      %mul = mul i32 %a, %a
+      ret i32 %mul
+    }
+  )";
+
+  Module &M = parseModule(ModuleString);
+  Function *F = M.getFunction("foo");
+  FunctionSpecializer &Specializer = getSpecializerFor(F);
+
+  DenseMap<Value *, Constant *> KnownConstants;
+  Constant *Two = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 2);
+
+  // %a is used twice by %mul, but the folded multiplication must only be
+  // accounted for once.
+  Cost Bonus = Specializer.getSpecializationBonus(
+      F->getArg(0), Two, KnownConstants);
+  EXPECT_EQ(Bonus, 1); // mul
+}
+
+TEST_F(FunctionSpecializationTest, ConstantExprSwitchCondition) {
+  const char *ModuleString = R"(
+    @g = global i32 0
+
+    define void @foo(i32 %i) {
+    entry:
+      switch i32 %i, label %default [ i32 1, label %case1 ]
+    case1:
+      br label %default
+    default:
+      ret void
+    }
+  )";
+
+  Module &M = parseModule(ModuleString);
+  Function *F = M.getFunction("foo");
+  FunctionSpecializer &Specializer = getSpecializerFor(F);
+
+  DenseMap<Value *, Constant *> KnownConstants;
+  Constant *CE = ConstantExpr::getPtrToInt(
+      M.getGlobalVariable("g"), IntegerType::getInt32Ty(M.getContext()));
+
+  // A constant expression does not select a switch case; this must neither
+  // assert nor produce a bonus.
+  Cost Bonus = Specializer.getSpecializationBonus(
+      F->getArg(0), CE, KnownConstants);
+  EXPECT_EQ(Bonus, 0);
+}