diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h --- a/llvm/include/llvm/InitializePasses.h +++ b/llvm/include/llvm/InitializePasses.h @@ -168,6 +168,7 @@ void initializeForwardControlFlowIntegrityPass(PassRegistry&); void initializeFuncletLayoutPass(PassRegistry&); void initializeFunctionImportLegacyPassPass(PassRegistry&); +void initializeFunctionSpecializationLegacyPassPass(PassRegistry &); void initializeGCMachineCodeAnalysisPass(PassRegistry&); void initializeGCModuleInfoPass(PassRegistry&); void initializeGCOVProfilerLegacyPassPass(PassRegistry&); diff --git a/llvm/include/llvm/LinkAllPasses.h b/llvm/include/llvm/LinkAllPasses.h --- a/llvm/include/llvm/LinkAllPasses.h +++ b/llvm/include/llvm/LinkAllPasses.h @@ -231,6 +231,7 @@ (void) llvm::createInjectTLIMappingsLegacyPass(); (void) llvm::createUnifyLoopExitsPass(); (void) llvm::createFixIrreduciblePass(); + (void)llvm::createFunctionSpecializationPass(); (void)new llvm::IntervalPartition(); (void)new llvm::ScalarEvolutionWrapperPass(); diff --git a/llvm/include/llvm/Transforms/IPO.h b/llvm/include/llvm/Transforms/IPO.h --- a/llvm/include/llvm/Transforms/IPO.h +++ b/llvm/include/llvm/Transforms/IPO.h @@ -169,6 +169,11 @@ /// ModulePass *createIPSCCPPass(); +//===----------------------------------------------------------------------===// +/// createFunctionSpecializationPass - This pass propagates constants from call +/// sites to the specialized version of the callee function. 
+ModulePass *createFunctionSpecializationPass(); + //===----------------------------------------------------------------------===// // /// createLoopExtractorPass - This pass extracts all natural loops from the diff --git a/llvm/include/llvm/Transforms/IPO/SCCP.h b/llvm/include/llvm/Transforms/IPO/SCCP.h --- a/llvm/include/llvm/Transforms/IPO/SCCP.h +++ b/llvm/include/llvm/Transforms/IPO/SCCP.h @@ -32,6 +32,14 @@ PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); }; +/// Pass to perform interprocedural constant propagation by specializing +/// functions +class FunctionSpecializationPass + : public PassInfoMixin { +public: + PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); +}; + } // end namespace llvm #endif // LLVM_TRANSFORMS_IPO_SCCP_H diff --git a/llvm/include/llvm/Transforms/Scalar/SCCP.h b/llvm/include/llvm/Transforms/Scalar/SCCP.h --- a/llvm/include/llvm/Transforms/Scalar/SCCP.h +++ b/llvm/include/llvm/Transforms/Scalar/SCCP.h @@ -22,6 +22,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" #include "llvm/IR/Module.h" @@ -42,6 +43,13 @@ bool runIPSCCP(Module &M, const DataLayout &DL, std::function GetTLI, function_ref getAnalysis); + +bool runFunctionSpecialization( + Module &M, const DataLayout &DL, + std::function GetTLI, + std::function GetTTI, + std::function GetAC, + function_ref GetAnalysis); } // end namespace llvm #endif // LLVM_TRANSFORMS_SCALAR_SCCP_H diff --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h --- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h +++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h @@ -130,6 +130,23 @@ /// Helper to return a Constant if \p LV is either a constant or a constant /// range with a single element. 
Constant *getConstant(const ValueLatticeElement &LV) const; + + /// Return a reference to the set of argument tracked functions. + SmallPtrSetImpl &getArgumentTrackedFunctions(); + + /// Mark argument \p A constant with value \p C in a new function + /// specialization. The argument's parent function is a specialization of the + /// original function \p F. All other arguments of the specialization inherit + /// the lattice state of their corresponding values in the original function. + void markArgInFuncSpecialization(Function *F, Argument *A, Constant *C); + + /// Mark all of the blocks in function \p F non-executable. Clients can use + /// this method to erase a function from the module (e.g., if it has been + /// completely specialized and is no longer needed). + void markFunctionUnreachable(Function *F); + + void visit(Instruction *I); + void visitCall(CallInst &I); }; } // namespace llvm diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -294,6 +294,7 @@ namespace llvm { extern cl::opt MaxDevirtIterations; extern cl::opt EnableConstraintElimination; +extern cl::opt EnableFunctionSpecialization; extern cl::opt EnableGVNHoist; extern cl::opt EnableGVNSink; extern cl::opt EnableHotColdSplit; @@ -1131,6 +1132,10 @@ for (auto &C : PipelineEarlySimplificationEPCallbacks) C(MPM, Level); + // Specialize functions with IPSCCP. + if (EnableFunctionSpecialization) + MPM.addPass(FunctionSpecializationPass()); + // Interprocedural constant propagation now that basic cleanup has occurred // and prior to optimizing globals. // FIXME: This position in the pipeline hasn't been carefully considered in @@ -1698,6 +1703,9 @@ // produce the same result as if we only do promotion here. 
MPM.addPass(PGOIndirectCallPromotion( true /* InLTO */, PGOOpt && PGOOpt->Action == PGOOptions::SampleUse)); + + if (EnableFunctionSpecialization) + MPM.addPass(FunctionSpecializationPass()); // Propagate constants at call sites into the functions they call. This // opens opportunities for globalopt (and inlining) by substituting function // pointers passed as arguments to direct uses of functions. diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -55,6 +55,7 @@ MODULE_PASS("extract-blocks", BlockExtractorPass()) MODULE_PASS("forceattrs", ForceFunctionAttrsPass()) MODULE_PASS("function-import", FunctionImportPass()) +MODULE_PASS("function-specialization", FunctionSpecializationPass()) MODULE_PASS("globaldce", GlobalDCEPass()) MODULE_PASS("globalopt", GlobalOptPass()) MODULE_PASS("globalsplit", GlobalSplitPass()) diff --git a/llvm/lib/Transforms/IPO/IPO.cpp b/llvm/lib/Transforms/IPO/IPO.cpp --- a/llvm/lib/Transforms/IPO/IPO.cpp +++ b/llvm/lib/Transforms/IPO/IPO.cpp @@ -32,6 +32,7 @@ initializeDAEPass(Registry); initializeDAHPass(Registry); initializeForceFunctionAttrsLegacyPassPass(Registry); + initializeFunctionSpecializationLegacyPassPass(Registry); initializeGlobalDCELegacyPassPass(Registry); initializeGlobalOptLegacyPassPass(Registry); initializeGlobalSplitPass(Registry); diff --git a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp --- a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -27,6 +27,7 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Verifier.h" +#include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Target/CGPassBuilderOption.h" @@ -43,6 +44,7 @@ #include "llvm/Transforms/Scalar/InstSimplifyPass.h" #include 
"llvm/Transforms/Scalar/LICM.h" #include "llvm/Transforms/Scalar/LoopUnrollPass.h" +#include "llvm/Transforms/Scalar/SCCP.h" #include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Vectorize.h" @@ -166,6 +168,10 @@ cl::desc( "Enable pass to eliminate conditions based on linear constraints.")); +cl::opt EnableFunctionSpecialization( + "enable-function-specialization", cl::init(false), cl::Hidden, + cl::desc("Enable Function Specialization pass")); + cl::opt AttributorRun( "attributor-enable", cl::Hidden, cl::init(AttributorRunOption::NONE), cl::desc("Enable the attributor inter-procedural deduction pass."), @@ -739,6 +745,10 @@ if (OptLevel > 2) MPM.add(createCallSiteSplittingPass()); + // Propagate constant function arguments by specializing the functions. + if (OptLevel > 2 && EnableFunctionSpecialization) + MPM.add(createFunctionSpecializationPass()); + MPM.add(createIPSCCPPass()); // IP SCCP MPM.add(createCalledValuePropagationPass()); @@ -994,6 +1004,10 @@ PM.add( createPGOIndirectCallPromotionLegacyPass(true, !PGOSampleUse.empty())); + // Propagate constant function arguments by specializing the functions. + if (EnableFunctionSpecialization) + PM.add(createFunctionSpecializationPass()); + // Propagate constants at call sites into the functions they call. This // opens opportunities for globalopt (and inlining) by substituting function // pointers passed as arguments to direct uses of functions. 
diff --git a/llvm/lib/Transforms/IPO/SCCP.cpp b/llvm/lib/Transforms/IPO/SCCP.cpp --- a/llvm/lib/Transforms/IPO/SCCP.cpp +++ b/llvm/lib/Transforms/IPO/SCCP.cpp @@ -14,6 +14,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/InitializePasses.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Scalar/SCCP.h" @@ -103,3 +104,92 @@ // createIPSCCPPass - This is the public interface to this file. ModulePass *llvm::createIPSCCPPass() { return new IPSCCPLegacyPass(); } + +PreservedAnalyses FunctionSpecializationPass::run(Module &M, + ModuleAnalysisManager &AM) { + const DataLayout &DL = M.getDataLayout(); + auto &FAM = AM.getResult(M).getManager(); + auto GetTLI = [&FAM](Function &F) -> TargetLibraryInfo & { + return FAM.getResult(F); + }; + auto GetTTI = [&FAM](Function &F) -> TargetTransformInfo & { + return FAM.getResult(F); + }; + auto GetAC = [&FAM](Function &F) -> AssumptionCache & { + return FAM.getResult(F); + }; + auto GetAnalysis = [&FAM](Function &F) -> AnalysisResultsForFn { + DominatorTree &DT = FAM.getResult(F); + return {std::make_unique( + F, DT, FAM.getResult(F)), + &DT, FAM.getCachedResult(F)}; + }; + + if (!runFunctionSpecialization(M, DL, GetTLI, GetTTI, GetAC, GetAnalysis)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserve(); + PA.preserve(); + PA.preserve(); + return PA; +} + +struct FunctionSpecializationLegacyPass : public ModulePass { + static char ID; // Pass identification, replacement for typeid + FunctionSpecializationLegacyPass() : ModulePass(ID) {} + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + } + + virtual bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + + const DataLayout &DL = M.getDataLayout(); + auto GetTLI = [this](Function &F) -> 
TargetLibraryInfo & { + return this->getAnalysis().getTLI(F); + }; + auto GetTTI = [this](Function &F) -> TargetTransformInfo & { + return this->getAnalysis().getTTI(F); + }; + auto GetAC = [this](Function &F) -> AssumptionCache & { + return this->getAnalysis().getAssumptionCache(F); + }; + + auto GetAnalysis = [this](Function &F) -> AnalysisResultsForFn { + DominatorTree &DT = + this->getAnalysis(F).getDomTree(); + return { + std::make_unique( + F, DT, + this->getAnalysis().getAssumptionCache( + F)), + nullptr, // We cannot preserve the DT or PDT with the legacy pass + nullptr}; // manager, so set them to nullptr. + }; + return runFunctionSpecialization(M, DL, GetTLI, GetTTI, GetAC, GetAnalysis); + } +}; + +char FunctionSpecializationLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN( + FunctionSpecializationLegacyPass, "function-specialization", + "Propagate constant arguments by specializing the function", false, false) + +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(FunctionSpecializationLegacyPass, "function-specialization", + "Propagate constant arguments by specializing the function", + false, false) + +ModulePass *llvm::createFunctionSpecializationPass() { + return new FunctionSpecializationLegacyPass(); +} diff --git a/llvm/lib/Transforms/Scalar/CMakeLists.txt b/llvm/lib/Transforms/Scalar/CMakeLists.txt --- a/llvm/lib/Transforms/Scalar/CMakeLists.txt +++ b/llvm/lib/Transforms/Scalar/CMakeLists.txt @@ -13,6 +13,7 @@ EarlyCSE.cpp FlattenCFGPass.cpp Float2Int.cpp + FunctionSpecialization.cpp GuardWidening.cpp GVN.cpp GVNHoist.cpp diff --git a/llvm/lib/Transforms/Scalar/FunctionSpecialization.cpp b/llvm/lib/Transforms/Scalar/FunctionSpecialization.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/Transforms/Scalar/FunctionSpecialization.cpp @@ -0,0 +1,637 @@ +//===- 
FunctionSpecialization.cpp - Function Specialization ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This specialises functions with constant parameters (e.g. functions, +// globals). Constant parameters like function pointers and constant globals +// are propagated to the callee by specializing the function. +// +// Current limitations: +// - It does not handle specialization of recursive functions, +// - It does not yet handle integer constants, and integer ranges, +// - Only 1 argument per function is specialised, +// - The cost-model could be further looked into, +// - We are not yet caching analysis results. +// +// Ideas: +// - With a function specialization attribute for arguments, we could have +// a direct way to steer function specialization, avoiding the cost-model, +// and thus control compile-times / code-size. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/DomTreeUpdater.h" +#include "llvm/Analysis/InlineCost.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Transforms/Scalar/SCCP.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/SizeOpts.h" + +using namespace llvm; + +#define DEBUG_TYPE "function-specialization" + +STATISTIC(NumFuncSpecialized, "Number of Functions Specialized"); + +static cl::opt ForceFunctionSpecialization( + "force-function-specialization", cl::init(false), cl::Hidden, + cl::desc("Force function specialization for every call site with a " + "constant argument")); + +static cl::opt FuncSpecializationMaxIters( + "func-specialization-max-iters", cl::Hidden, + cl::desc("The maximum number of iterations function specialization is run"), + cl::init(1)); + +static cl::opt MaxConstantsThreshold( + "func-specialization-max-constants", cl::Hidden, + cl::desc("The maximum number of clones allowed for a single function " + "specialization"), + cl::init(3)); + +static cl::opt + AvgLoopIterationCount("func-specialization-avg-iters-cost", cl::Hidden, + cl::desc("Average loop iteration count cost"), + cl::init(10)); + +// Helper to check if \p LV is either overdefined or a constant int. +static bool isOverdefined(const ValueLatticeElement &LV) { + return !LV.isUnknownOrUndef() && !LV.isConstant(); +} + +class FunctionSpecializer { + + /// The IPSCCP Solver. + SCCPSolver &Solver; + + /// Analyses used to help determine if a function should be specialized. 
+ std::function GetAC; + std::function GetTTI; + std::function GetTLI; + + SmallPtrSet SpecializedFuncs; + +public: + FunctionSpecializer(SCCPSolver &Solver, + std::function GetAC, + std::function GetTTI, + std::function GetTLI) + : Solver(Solver), GetAC(GetAC), GetTTI(GetTTI), GetTLI(GetTLI) {} + + /// Attempt to specialize functions in the module to enable constant + /// propagation across function boundaries. + /// + /// \returns true if at least one function is specialized. + bool + specializeFunctions(SmallVectorImpl &FuncDecls, + SmallVectorImpl &CurrentSpecializations) { + + // Attempt to specialize the argument-tracked functions. + bool Changed = false; + for (auto *F : FuncDecls) { + if (specializeFunction(F, CurrentSpecializations)) { + Changed = true; + LLVM_DEBUG(dbgs() << "FnSpecialization: Can specialize this func.\n"); + } else { + LLVM_DEBUG( + dbgs() << "FnSpecialization: Cannot specialize this func.\n"); + } + } + + for (auto *SpecializedFunc : CurrentSpecializations) { + SpecializedFuncs.insert(SpecializedFunc); + + // TODO: If we want to support specializing specialized functions, + // initialize here the state of the newly created functions, marking + // them argument-tracked and executable. + + // Replace the function arguments for the specialized functions. + for (Argument &Arg : SpecializedFunc->args()) + if (!Arg.use_empty() && tryToReplaceWithConstant(&Arg)) + LLVM_DEBUG(dbgs() << "FnSpecialization: Replaced constant argument: " + << Arg.getName() << "\n"); + } + return Changed; + } + + bool tryToReplaceWithConstant(Value *V) { + if (!V->getType()->isSingleValueType() || isa(V) || + V->user_empty()) + return false; + + const ValueLatticeElement &IV = Solver.getLatticeValueFor(V); + if (isOverdefined(IV)) + return false; + auto *Const = IV.isConstant() ? Solver.getConstant(IV) + : UndefValue::get(V->getType()); + V->replaceAllUsesWith(Const); + + // TODO: Update the solver here if we want to specialize specialized + // functions. 
+ return true; + } + +private: + /// This function decides whether to specialize function \p F based on the + /// known constant values its arguments can take on. Specialization is + /// performed on the first interesting argument. Specializations based on + /// additional arguments will be evaluated on following iterations of the + /// main IPSCCP solve loop. \returns true if the function is specialized and + /// false otherwise. + bool specializeFunction(Function *F, + SmallVectorImpl &Specializations) { + + // Do not specialize the cloned function again. + if (SpecializedFuncs.contains(F)) { + return false; + } + + // If we're optimizing the function for size, we shouldn't specialize it. + if (F->hasOptSize() || + shouldOptimizeForSize(F, nullptr, nullptr, PGSOQueryType::IRPass)) + return false; + + // Exit if the function is not executable. There's no point in specializing + // a dead function. + if (!Solver.isBlockExecutable(&F->getEntryBlock())) + return false; + + LLVM_DEBUG(dbgs() << "FnSpecialization: Try function: " << F->getName() + << "\n"); + // Determine if we should specialize the function based on the values the + // argument can take on. If specialization is not profitable, we continue + // on to the next argument. + for (Argument &A : F->args()) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing arg: " << A.getName() + << "\n"); + // True if this will be a partial specialization. We will need to keep + // the original function around in addition to the added specializations. + bool IsPartial = true; + + // Determine if this argument is interesting. If we know the argument can + // take on any constant values, they are collected in Constants. If the + // argument can only ever equal a constant value in Constants, the + // function will be completely specialized, and the IsPartial flag will + // be set to false by isArgumentInteresting (that function only adds + // values to the Constants list that are deemed profitable). 
+ SmallVector Constants; + if (!isArgumentInteresting(&A, Constants, IsPartial)) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Argument is not interesting\n"); + continue; + } + + assert(!Constants.empty() && "No constants on which to specialize"); + LLVM_DEBUG(dbgs() << "FnSpecialization: Argument is interesting!\n" + << "FnSpecialization: Specializing '" << F->getName() + << "' on argument: " << A << "\n" + << "FnSpecialization: Constants are:\n\n"; + for (unsigned I = 0; I < Constants.size(); ++I) dbgs() + << *Constants[I] << "\n"; + dbgs() << "FnSpecialization: End of constants\n\n"); + + // Create a version of the function in which the argument is marked + // constant with the given value. + for (auto *C : Constants) { + // Clone the function. We leave the ValueToValueMap empty to allow + // IPSCCP to propagate the constant arguments. + ValueToValueMapTy EmptyMap; + Function *Clone = CloneFunction(F, EmptyMap); + Argument *ClonedArg = Clone->arg_begin() + A.getArgNo(); + + // Rewrite calls to the function so that they call the clone instead. + rewriteCallSites(F, Clone, *ClonedArg, C); + + // Initialize the lattice state of the arguments of the function clone, + // marking the argument on which we specialized the function constant + // with the given value. + Solver.markArgInFuncSpecialization(F, ClonedArg, C); + + // Mark all the specialized functions + Specializations.push_back(Clone); + NumFuncSpecialized++; + } + + // TODO: if we want to support specialize specialized functions, and if + // the function has been completely specialized, the original function is + // no longer needed, so we would need to mark it unreachable here. + + // FIXME: Only one argument per function. + return true; + } + + return false; + } + + /// Compute the cost of specializing function \p F. + InstructionCost getSpecializationCost(Function *F) { + // Compute the code metrics for the function. 
+ SmallPtrSet EphValues; + CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues); + CodeMetrics Metrics; + for (BasicBlock &BB : *F) + Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues); + + // If the code metrics reveal that we shouldn't duplicate the function, we + // shouldn't specialize it. Set the specialization cost to the maximum. + if (Metrics.notDuplicatable) + return std::numeric_limits::max(); + + // Otherwise, set the specialization cost to be the cost of all the + // instructions in the function and penalty for specializing more functions. + unsigned Penalty = (NumFuncSpecialized + 1); + return Metrics.NumInsts * InlineConstants::InstrCost * Penalty; + } + + InstructionCost getUserBonus(User *U, llvm::TargetTransformInfo &TTI, + LoopInfo &LI) { + auto *I = dyn_cast_or_null(U); + // If not an instruction we do not know how to evaluate. + // Keep minimum possible cost for now so that it doesn't affect + // specialization. + if (!I) + return std::numeric_limits::min(); + + auto Cost = TTI.getUserCost(U, TargetTransformInfo::TCK_SizeAndLatency); + + // Traverse recursively if there are more uses. + // TODO: Any other instructions to be added here? + if (I->mayReadFromMemory() || I->isCast()) + for (auto *User : I->users()) + Cost += getUserBonus(User, TTI, LI); + + // Increase the cost if it is inside the loop. + auto LoopDepth = LI.getLoopDepth(I->getParent()) + 1; + Cost *= (AvgLoopIterationCount ^ LoopDepth); + return Cost; + } + + /// Compute a bonus for replacing argument \p A with constant \p C. 
+ InstructionCost getSpecializationBonus(Argument *A, Constant *C) { + Function *F = A->getParent(); + DominatorTree DT(*F); + LoopInfo LI(DT); + auto &TTI = (GetTTI)(*F); + LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for: " << *A + << "\n"); + + InstructionCost TotalCost = 0; + for (auto *U : A->users()) { + TotalCost += getUserBonus(U, TTI, LI); + LLVM_DEBUG(dbgs() << "FnSpecialization: User cost "; + TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n"); + } + + // The below heuristic is only concerned with exposing inlining + // opportunities via indirect call promotion. If the argument is not a + // function pointer, give up. + if (!isa(A->getType()) || + !isa(A->getType()->getPointerElementType())) + return TotalCost; + + // Since the argument is a function pointer, its incoming constant values + // should be functions or constant expressions. The code below attempts to + // look through cast expressions to find the function that will be called. + Value *CalledValue = C; + while (isa(CalledValue) && + cast(CalledValue)->isCast()) + CalledValue = cast(CalledValue)->getOperand(0); + Function *CalledFunction = dyn_cast(CalledValue); + if (!CalledFunction) + return TotalCost; + + // Get TTI for the called function (used for the inline cost). + auto &CalleeTTI = (GetTTI)(*CalledFunction); + + // Look at all the call sites whose called value is the argument. + // Specializing the function on the argument would allow these indirect + // calls to be promoted to direct calls. If the indirect call promotion + // would likely enable the called function to be inlined, specializing is a + // good idea. + int Bonus = 0; + for (User *U : A->users()) { + if (!isa(U) && !isa(U)) + continue; + auto *CS = cast(U); + if (CS->getCalledOperand() != A) + continue; + + // Get the cost of inlining the called function at this call site. Note + // that this is only an estimate. 
The called function may eventually + // change in a way that leads to it not being inlined here, even though + // inlining looks profitable now. For example, one of its called + // functions may be inlined into it, making the called function too large + // to be inlined into this call site. + // + // We apply a boost for performing indirect call promotion by increasing + // the default threshold by the threshold for indirect calls. + auto Params = getInlineParams(); + Params.DefaultThreshold += InlineConstants::IndirectCallThreshold; + InlineCost IC = + getInlineCost(*CS, CalledFunction, Params, CalleeTTI, GetAC, GetTLI); + + // We clamp the bonus for this call to be between zero and the default + // threshold. + if (IC.isAlways()) + Bonus += Params.DefaultThreshold; + else if (IC.isVariable() && IC.getCostDelta() > 0) + Bonus += IC.getCostDelta(); + } + + return TotalCost + Bonus; + } + + /// Determine if we should specialize a function based on the incoming values + /// of the given argument. + /// + /// This function implements the goal-directed heuristic. It determines if + /// specializing the function based on the incoming values of argument \p A + /// would result in any significant optimization opportunities. If + /// optimization opportunities exist, the constant values of \p A on which to + /// specialize the function are collected in \p Constants. If the values in + /// \p Constants represent the complete set of values that \p A can take on, + /// the function will be completely specialized, and the \p IsPartial flag is + /// set to false. + /// + /// \returns true if the function should be specialized on the given + /// argument. + bool isArgumentInteresting(Argument *A, + SmallVectorImpl &Constants, + bool &IsPartial) { + Function *F = A->getParent(); + + // For now, don't attempt to specialize functions based on the values of + // composite types. 
+ if (!A->getType()->isSingleValueType() || A->user_empty()) + return false; + + // If the argument isn't overdefined, there's nothing to do. It should + // already be constant. + if (!Solver.getLatticeValueFor(A).isOverdefined()) { + LLVM_DEBUG(dbgs() << "FnSpecialization: nothing to do, arg is already " + << "constant?\n"); + return false; + } + + // Collect the constant values that the argument can take on. If the + // argument can't take on any constant values, we aren't going to + // specialize the function. While it's possible to specialize the function + // based on non-constant arguments, there's likely not much benefit to + // constant propagation in doing so. + // + // TODO 1: currently it won't specialize if there are over the threshold of + // calls using the same argument, e.g foo(a) x 4 and foo(b) x 1, but it + // might be beneficial to take the occurrences into account in the cost + // model, so we would need to find the unique constants. + // + // TODO 2: this currently does not support constants, i.e. integer ranges. + // + SmallVector PossibleConstants; + bool AllConstant = getPossibleConstants(A, PossibleConstants); + if (PossibleConstants.empty()) { + LLVM_DEBUG(dbgs() << "FnSpecialization: no possible constants found\n"); + return false; + } + if (PossibleConstants.size() > MaxConstantsThreshold) { + LLVM_DEBUG(dbgs() << "FnSpecialization: number of constants found exceed " + << "the maximum number of constants threshold.\n"); + return false; + } + + // Determine if it would be profitable to create a specialization of the + // function where the argument takes on the given constant value. If so, + // add the constant to Constants. 
+ auto FnSpecCost = getSpecializationCost(F); + LLVM_DEBUG(dbgs() << "FnSpecialization: func specialisation cost: "; + FnSpecCost.print(dbgs()); dbgs() << "\n"); + + for (auto *C : PossibleConstants) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Constant: " << *C << "\n"); + if (ForceFunctionSpecialization) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Forced!\n"); + Constants.push_back(C); + continue; + } + if (getSpecializationBonus(A, C) > FnSpecCost) { + LLVM_DEBUG(dbgs() << "FnSpecialization: profitable!\n"); + Constants.push_back(C); + } else { + LLVM_DEBUG(dbgs() << "FnSpecialization: not profitable\n"); + } + } + + // None of the constant values the argument can take on were deemed good + // candidates on which to specialize the function. + if (Constants.empty()) + return false; + + // This will be a partial specialization if some of the constants were + // rejected due to their profitability. + IsPartial = !AllConstant || PossibleConstants.size() != Constants.size(); + + return true; + } + + /// Collect in \p Constants all the constant values that argument \p A can + /// take on. + /// + /// \returns true if all of the values the argument can take on are constant + /// (e.g., the argument's parent function cannot be called with an + /// overdefined value). + bool getPossibleConstants(Argument *A, + SmallVectorImpl &Constants) { + Function *F = A->getParent(); + bool AllConstant = true; + + // Iterate over all the call sites of the argument's parent function. + for (User *U : F->users()) { + if (!isa(U) && !isa(U)) + continue; + auto &CS = *cast(U); + + // If the parent of the call site will never be executed, we don't need + // to worry about the passed value. + if (!Solver.isBlockExecutable(CS.getParent())) + continue; + + auto *V = CS.getArgOperand(A->getArgNo()); + // TrackValueOfGlobalVariable only tracks scalar global variables. 
+ if (auto *GV = dyn_cast(V)) { + if (!GV->getValueType()->isSingleValueType()) { + return false; + } + } + + // Get the lattice value for the value the call site passes to the + // argument. If this value is not constant, move on to the next call + // site. Additionally, set the AllConstant flag to false. + if (V != A && !Solver.getLatticeValueFor(V).isConstant()) { + AllConstant = false; + continue; + } + + // Add the constant to the set. + if (auto *C = dyn_cast(CS.getArgOperand(A->getArgNo()))) + Constants.push_back(C); + } + + // If the argument can only take on constant values, AllConstant will be + // true. + return AllConstant; + } + + /// Rewrite calls to function \p F to call function \p Clone instead. + /// + /// This function modifies calls to function \p F whose argument at index \p + /// ArgNo is equal to constant \p C. The calls are rewritten to call function + /// \p Clone instead. + void rewriteCallSites(Function *F, Function *Clone, Argument &Arg, + Constant *C) { + unsigned ArgNo = Arg.getArgNo(); + SmallVector CallSitesToRewrite; + for (auto *U : F->users()) { + if (!isa(U) && !isa(U)) + continue; + auto &CS = *cast(U); + if (!CS.getCalledFunction() || CS.getCalledFunction() != F) + continue; + CallSitesToRewrite.push_back(&CS); + } + for (auto *CS : CallSitesToRewrite) { + + if ((CS->getFunction() == Clone && CS->getArgOperand(ArgNo) == &Arg) || + CS->getArgOperand(ArgNo) == C) { + CS->setCalledFunction(Clone); + Solver.markOverdefined(CS); + } + } + } +}; + +/// Function to clean up the left over intrinsics from SCCP util. 
+static void cleanup(Module &M) { + for (Function &F : M) { + for (BasicBlock &BB : F) { + for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) { + Instruction *Inst = &*BI++; + if (auto *II = dyn_cast(Inst)) { + if (II->getIntrinsicID() == Intrinsic::ssa_copy) { + Value *Op = II->getOperand(0); + Inst->replaceAllUsesWith(Op); + Inst->eraseFromParent(); + } + } + } + } + } +} + +bool llvm::runFunctionSpecialization( + Module &M, const DataLayout &DL, + std::function GetTLI, + std::function GetTTI, + std::function GetAC, + function_ref GetAnalysis) { + SCCPSolver Solver(DL, GetTLI, M.getContext()); + FunctionSpecializer FS(Solver, GetAC, GetTTI, GetTLI); + bool Changed = false; + + // Loop over all functions, marking arguments to those with their addresses + // taken or that are external as overdefined. + for (Function &F : M) { + if (F.isDeclaration()) + continue; + + LLVM_DEBUG(dbgs() << "\nFnSpecialization: Analysing decl: " << F.getName() + << "\n"); + Solver.addAnalysis(F, GetAnalysis(F)); + + // Determine if we can track the function's arguments. If so, add the + // function to the solver's set of argument-tracked functions. + if (canTrackArgumentsInterprocedurally(&F)) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Can track arguments\n"); + Solver.addArgumentTrackedFunction(&F); + continue; + } else { + LLVM_DEBUG(dbgs() << "FnSpecialization: Can't track arguments!\n" + << "FnSpecialization: Doesn't have local linkage, or " + << "has its address taken\n"); + } + + // Assume the function is called. + Solver.markBlockExecutable(&F.front()); + + // Assume nothing about the incoming arguments. + for (Argument &AI : F.args()) + Solver.markOverdefined(&AI); + } + + // Determine if we can track any of the module's global variables. If so, add + // the global variables we can track to the solver's set of tracked global + // variables. 
+ for (GlobalVariable &G : M.globals()) { + G.removeDeadConstantUsers(); + if (canTrackGlobalVariableInterprocedurally(&G)) + Solver.trackValueOfGlobalVariable(&G); + } + + // Solve for constants. + auto RunSCCPSolver = [&](auto &WorkList) { + bool ResolvedUndefs = true; + + while (ResolvedUndefs) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Running solver\n"); + Solver.solve(); + LLVM_DEBUG(dbgs() << "FnSpecialization: Resolving undefs\n"); + ResolvedUndefs = false; + for (Function *F : WorkList) + if (Solver.resolvedUndefsIn(*F)) + ResolvedUndefs = true; + } + + for (auto *F : WorkList) { + for (BasicBlock &BB : *F) { + if (!Solver.isBlockExecutable(&BB)) + continue; + for (auto &I : make_early_inc_range(BB)) + FS.tryToReplaceWithConstant(&I); + } + } + }; + + auto &TrackedFuncs = Solver.getArgumentTrackedFunctions(); + SmallVector FuncDecls(TrackedFuncs.begin(), + TrackedFuncs.end()); +#ifndef NDEBUG + LLVM_DEBUG(dbgs() << "FnSpecialization: Worklist fn decls:\n"); + for (auto *F : FuncDecls) + LLVM_DEBUG(dbgs() << "FnSpecialization: *) " << F->getName() << "\n"); +#endif + + // Initially resolve the constants in all the argument tracked functions. + RunSCCPSolver(FuncDecls); + + SmallVector CurrentSpecializations; + unsigned I = 0; + while (FuncSpecializationMaxIters != I++ && + FS.specializeFunctions(FuncDecls, CurrentSpecializations)) { + // TODO: run the solver here for the specialized functions only if we want + // to specialize recursively. + + CurrentSpecializations.clear(); + Changed = true; + } + + // Clean up the IR by removing ssa_copy intrinsics. 
+ cleanup(M); + + return Changed; +} diff --git a/llvm/lib/Transforms/Scalar/SCCP.cpp b/llvm/lib/Transforms/Scalar/SCCP.cpp --- a/llvm/lib/Transforms/Scalar/SCCP.cpp +++ b/llvm/lib/Transforms/Scalar/SCCP.cpp @@ -97,9 +97,6 @@ return !LV.isUnknownOrUndef() && !isConstant(LV); } - - - static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { Constant *Const = nullptr; if (V->getType()->isStructTy()) { @@ -162,7 +159,7 @@ if (tryToReplaceWithConstant(Solver, &Inst)) { if (Inst.isSafeToRemove()) Inst.eraseFromParent(); - // Hey, we just changed something! + MadeChanges = true; ++InstRemovedStat; } else if (isa(&Inst)) { diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -307,8 +307,6 @@ void visitLoadInst(LoadInst &I); void visitGetElementPtrInst(GetElementPtrInst &I); - void visitCallInst(CallInst &I) { visitCallBase(I); } - void visitInvokeInst(InvokeInst &II) { visitCallBase(II); visitTerminator(II); @@ -334,6 +332,8 @@ AnalysisResults.insert({&F, std::move(A)}); } + void visitCallInst(CallInst &I) { visitCallBase(I); } + bool markBlockExecutable(BasicBlock *BB); const PredicateBase *getPredicateInfoFor(Instruction *I) { @@ -447,6 +447,17 @@ bool isStructLatticeConstant(Function *F, StructType *STy); Constant *getConstant(const ValueLatticeElement &LV) const; + + SmallPtrSetImpl &getArgumentTrackedFunctions() { + return TrackingIncomingArguments; + } + + void markArgInFuncSpecialization(Function *F, Argument *A, Constant *C); + + void markFunctionUnreachable(Function *F) { + for (auto &BB : *F) + BBExecutable.erase(&BB); + } }; } // namespace llvm @@ -515,6 +526,25 @@ return nullptr; } +void SCCPInstVisitor::markArgInFuncSpecialization(Function *F, Argument *A, + Constant *C) { + assert(F->arg_size() == A->getParent()->arg_size() && + "Functions should have the same number of arguments"); + + // Mark the argument 
constant in the new function. + markConstant(A, C); + + // For the remaining arguments in the new function, copy the lattice state + // over from the old function. + for (auto I = F->arg_begin(), J = A->getParent()->arg_begin(), + E = F->arg_end(); + I != E; ++I, ++J) + if (J != A && ValueState.count(I)) { + ValueState[J] = ValueState[I]; + pushToWorkList(ValueState[J], J); + } +} + void SCCPInstVisitor::visitInstruction(Instruction &I) { // All the instructions we don't do any special handling for just // go to overdefined. @@ -1574,7 +1604,7 @@ LLVMContext &Ctx) : Visitor(new SCCPInstVisitor(DL, std::move(GetTLI), Ctx)) {} -SCCPSolver::~SCCPSolver() { } +SCCPSolver::~SCCPSolver() {} void SCCPSolver::addAnalysis(Function &F, AnalysisResultsForFn A) { return Visitor->addAnalysis(F, std::move(A)); @@ -1664,3 +1694,20 @@ Constant *SCCPSolver::getConstant(const ValueLatticeElement &LV) const { return Visitor->getConstant(LV); } + +SmallPtrSetImpl &SCCPSolver::getArgumentTrackedFunctions() { + return Visitor->getArgumentTrackedFunctions(); +} + +void SCCPSolver::markArgInFuncSpecialization(Function *F, Argument *A, + Constant *C) { + Visitor->markArgInFuncSpecialization(F, A, C); +} + +void SCCPSolver::markFunctionUnreachable(Function *F) { + Visitor->markFunctionUnreachable(F); +} + +void SCCPSolver::visit(Instruction *I) { Visitor->visit(I); } + +void SCCPSolver::visitCall(CallInst &I) { Visitor->visitCall(I); } diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll @@ -0,0 +1,56 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -function-specialization -inline -instcombine -S < %s | FileCheck %s + +; TODO: this is a case that would be interesting to support, but we don't yet +; at the 
moment. + +@Global = internal constant i32 1, align 4 + +define internal void @recursiveFunc(i32* nocapture readonly %arg) { +; CHECK-LABEL: @recursiveFunc( +; CHECK-NEXT: [[TEMP:%.*]] = alloca i32, align 4 +; CHECK-NEXT: [[ARG_LOAD:%.*]] = load i32, i32* [[ARG:%.*]], align 4 +; CHECK-NEXT: [[ARG_CMP:%.*]] = icmp slt i32 [[ARG_LOAD]], 4 +; CHECK-NEXT: br i1 [[ARG_CMP]], label [[BLOCK6:%.*]], label [[RET_BLOCK:%.*]] +; CHECK: block6: +; CHECK-NEXT: call void @print_val(i32 [[ARG_LOAD]]) +; CHECK-NEXT: [[ARG_ADD:%.*]] = add nsw i32 [[ARG_LOAD]], 1 +; CHECK-NEXT: store i32 [[ARG_ADD]], i32* [[TEMP]], align 4 +; CHECK-NEXT: call void @recursiveFunc(i32* nonnull [[TEMP]]) +; CHECK-NEXT: br label [[RET_BLOCK]] +; CHECK: ret.block: +; CHECK-NEXT: ret void +; + %temp = alloca i32, align 4 + %arg.load = load i32, i32* %arg, align 4 + %arg.cmp = icmp slt i32 %arg.load, 4 + br i1 %arg.cmp, label %block6, label %ret.block + +block6: + call void @print_val(i32 %arg.load) + %arg.add = add nsw i32 %arg.load, 1 + store i32 %arg.add, i32* %temp, align 4 + call void @recursiveFunc(i32* nonnull %temp) + br label %ret.block + +ret.block: + ret void +} + +define i32 @main() { +; CHECK-LABEL: @main( +; CHECK-NEXT: [[TEMP_I:%.*]] = alloca i32, align 4 +; CHECK-NEXT: [[TMP1:%.*]] = bitcast i32* [[TEMP_I]] to i8* +; CHECK-NEXT: call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull [[TMP1]]) +; CHECK-NEXT: call void @print_val(i32 1) +; CHECK-NEXT: store i32 2, i32* [[TEMP_I]], align 4 +; CHECK-NEXT: call void @recursiveFunc(i32* nonnull [[TEMP_I]]) +; CHECK-NEXT: [[TMP2:%.*]] = bitcast i32* [[TEMP_I]] to i8* +; CHECK-NEXT: call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull [[TMP2]]) +; CHECK-NEXT: ret i32 0 +; + call void @recursiveFunc(i32* nonnull @Global) + ret i32 0 +} + +declare dso_local void @print_val(i32) diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization.ll new file mode 
100644 --- /dev/null +++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization.ll @@ -0,0 +1,50 @@ +; RUN: opt -function-specialization -deadargelim -inline -S < %s | FileCheck %s + +; CHECK-LABEL: @main(i64 %x, i1 %flag) { +; CHECK: entry: +; CHECK-NEXT: br i1 %flag, label %plus, label %minus +; CHECK: plus: +; CHECK-NEXT: [[TMP0:%.+]] = add i64 %x, 1 +; CHECK-NEXT: br label %merge +; CHECK: minus: +; CHECK-NEXT: [[TMP1:%.+]] = sub i64 %x, 1 +; CHECK-NEXT: br label %merge +; CHECK: merge: +; CHECK-NEXT: [[TMP2:%.+]] = phi i64 [ [[TMP0]], %plus ], [ [[TMP1]], %minus ] +; CHECK-NEXT: ret i64 [[TMP2]] +; CHECK-NEXT: } +; +define i64 @main(i64 %x, i1 %flag) { +entry: + br i1 %flag, label %plus, label %minus + +plus: + %tmp0 = call i64 @compute(i64 %x, i64 (i64)* @plus) + br label %merge + +minus: + %tmp1 = call i64 @compute(i64 %x, i64 (i64)* @minus) + br label %merge + +merge: + %tmp2 = phi i64 [ %tmp0, %plus ], [ %tmp1, %minus] + ret i64 %tmp2 +} + +define internal i64 @compute(i64 %x, i64 (i64)* %binop) { +entry: + %tmp0 = call i64 %binop(i64 %x) + ret i64 %tmp0 +} + +define internal i64 @plus(i64 %x) { +entry: + %tmp0 = add i64 %x, 1 + ret i64 %tmp0 +} + +define internal i64 @minus(i64 %x) { +entry: + %tmp0 = sub i64 %x, 1 + ret i64 %tmp0 +} diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization2.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization2.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization2.ll @@ -0,0 +1,87 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -function-specialization -deadargelim -S < %s | FileCheck %s +; RUN: opt -function-specialization -func-specialization-max-iters=1 -deadargelim -S < %s | FileCheck %s +; RUN: opt -function-specialization -func-specialization-max-iters=0 -deadargelim -S < %s | FileCheck %s --check-prefix=DISABLED +; RUN: opt -function-specialization
-func-specialization-avg-iters-cost=1 -deadargelim -S < %s | FileCheck %s + +; DISABLED-NOT: @func.1( +; DISABLED-NOT: @func.2( + +define internal i32 @func(i32* %0, i32 %1, void (i32*)* nocapture %2) { + %4 = alloca i32, align 4 + store i32 %1, i32* %4, align 4 + %5 = load i32, i32* %4, align 4 + %6 = icmp slt i32 %5, 1 + br i1 %6, label %14, label %7 + +7: ; preds = %3 + %8 = load i32, i32* %4, align 4 + %9 = sext i32 %8 to i64 + %10 = getelementptr inbounds i32, i32* %0, i64 %9 + call void %2(i32* %10) + %11 = load i32, i32* %4, align 4 + %12 = add nsw i32 %11, -1 + %13 = call i32 @func(i32* %0, i32 %12, void (i32*)* %2) + br label %14 + +14: ; preds = %3, %7 + ret i32 0 +} + +define internal void @increment(i32* nocapture %0) { + %2 = load i32, i32* %0, align 4 + %3 = add nsw i32 %2, 1 + store i32 %3, i32* %0, align 4 + ret void +} + +define internal void @decrement(i32* nocapture %0) { + %2 = load i32, i32* %0, align 4 + %3 = add nsw i32 %2, -1 + store i32 %3, i32* %0, align 4 + ret void +} + +define i32 @main(i32* %0, i32 %1) { +; CHECK: [[TMP3:%.*]] = call i32 @func.2(i32* [[TMP0:%.*]], i32 [[TMP1:%.*]]) + %3 = call i32 @func(i32* %0, i32 %1, void (i32*)* nonnull @increment) +; CHECK: [[TMP4:%.*]] = call i32 @func.1(i32* [[TMP0]], i32 [[TMP3]]) + %4 = call i32 @func(i32* %0, i32 %3, void (i32*)* nonnull @decrement) + ret i32 %4 +} + +; CHECK: @func.1( +; CHECK: [[TMP3:%.*]] = alloca i32, align 4 +; CHECK: store i32 [[TMP1:%.*]], i32* [[TMP3]], align 4 +; CHECK: [[TMP4:%.*]] = load i32, i32* [[TMP3]], align 4 +; CHECK: [[TMP5:%.*]] = icmp slt i32 [[TMP4]], 1 +; CHECK: br i1 [[TMP5]], label [[TMP13:%.*]], label [[TMP6:%.*]] +; CHECK: 6: +; CHECK: [[TMP7:%.*]] = load i32, i32* [[TMP3]], align 4 +; CHECK: [[TMP8:%.*]] = sext i32 [[TMP7]] to i64 +; CHECK: [[TMP9:%.*]] = getelementptr inbounds i32, i32* [[TMP0:%.*]], i64 [[TMP8]] +; CHECK: call void @decrement(i32* [[TMP9]]) +; CHECK: [[TMP10:%.*]] = load i32, i32* [[TMP3]], align 4 +; CHECK: [[TMP11:%.*]] = add 
nsw i32 [[TMP10]], -1 +; CHECK: [[TMP12:%.*]] = call i32 @func.1(i32* [[TMP0]], i32 [[TMP11]]) +; CHECK: br label [[TMP13]] +; CHECK: 13: +; CHECK: ret i32 0 +; +; +; CHECK: @func.2( +; CHECK: [[TMP3:%.*]] = alloca i32, align 4 +; CHECK: store i32 [[TMP1:%.*]], i32* [[TMP3]], align 4 +; CHECK: [[TMP4:%.*]] = load i32, i32* [[TMP3]], align 4 +; CHECK: [[TMP5:%.*]] = icmp slt i32 [[TMP4]], 1 +; CHECK: br i1 [[TMP5]], label [[TMP13:%.*]], label [[TMP6:%.*]] +; CHECK: 6: +; CHECK: [[TMP7:%.*]] = load i32, i32* [[TMP3]], align 4 +; CHECK: [[TMP8:%.*]] = sext i32 [[TMP7]] to i64 +; CHECK: [[TMP9:%.*]] = getelementptr inbounds i32, i32* [[TMP0:%.*]], i64 [[TMP8]] +; CHECK: call void @increment(i32* [[TMP9]]) +; CHECK: [[TMP10:%.*]] = load i32, i32* [[TMP3]], align 4 +; CHECK: [[TMP11:%.*]] = add nsw i32 [[TMP10]], -1 +; CHECK: [[TMP12:%.*]] = call i32 @func.2(i32* [[TMP0]], i32 [[TMP11]]) +; CHECK: br label [[TMP13]] +; CHECK: ret i32 0 diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization3.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization3.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization3.ll @@ -0,0 +1,58 @@ +; RUN: opt -function-specialization -func-specialization-avg-iters-cost=3 -S < %s | \ +; RUN: FileCheck %s --check-prefixes=COMMON,DISABLED +; RUN: opt -function-specialization -func-specialization-avg-iters-cost=4 -S < %s | \ +; RUN: FileCheck %s --check-prefixes=COMMON,FORCE +; RUN: opt -function-specialization -force-function-specialization -S < %s | \ +; RUN: FileCheck %s --check-prefixes=COMMON,FORCE +; RUN: opt -function-specialization -func-specialization-avg-iters-cost=3 -force-function-specialization -S < %s | \ +; RUN: FileCheck %s --check-prefixes=COMMON,FORCE + +; Test for specializing a constant global. 
+ +target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" + +@A = external dso_local constant i32, align 4 +@B = external dso_local constant i32, align 4 + +define dso_local i32 @bar(i32 %x, i32 %y) { +; COMMON-LABEL: @bar +; FORCE: %call = call i32 @foo.2(i32 %x, i32* @A) +; FORCE: %call1 = call i32 @foo.1(i32 %y, i32* @B) +; DISABLED-NOT: %call1 = call i32 @foo.1( +entry: + %tobool = icmp ne i32 %x, 0 + br i1 %tobool, label %if.then, label %if.else + +if.then: + %call = call i32 @foo(i32 %x, i32* @A) + br label %return + +if.else: + %call1 = call i32 @foo(i32 %y, i32* @B) + br label %return + +return: + %retval.0 = phi i32 [ %call, %if.then ], [ %call1, %if.else ] + ret i32 %retval.0 +} + +; FORCE: define internal i32 @foo.1(i32 %x, i32* %b) { +; FORCE-NEXT: entry: +; FORCE-NEXT: %0 = load i32, i32* @B, align 4 +; FORCE-NEXT: %add = add nsw i32 %x, %0 +; FORCE-NEXT: ret i32 %add +; FORCE-NEXT: } + +; FORCE: define internal i32 @foo.2(i32 %x, i32* %b) { +; FORCE-NEXT: entry: +; FORCE-NEXT: %0 = load i32, i32* @A, align 4 +; FORCE-NEXT: %add = add nsw i32 %x, %0 +; FORCE-NEXT: ret i32 %add +; FORCE-NEXT: } + +define internal i32 @foo(i32 %x, i32* %b) { +entry: + %0 = load i32, i32* %b, align 4 + %add = add nsw i32 %x, %0 + ret i32 %add +} diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization4.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization4.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization4.ll @@ -0,0 +1,60 @@ +; RUN: opt -function-specialization -force-function-specialization \ +; RUN: -func-specialization-max-constants=2 -S < %s | FileCheck %s + +; RUN: opt -function-specialization -force-function-specialization \ +; RUN: -func-specialization-max-constants=1 -S < %s | FileCheck %s --check-prefix=CONST1 + +target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" + +@A = external dso_local constant i32, align 4 +@B = 
external dso_local constant i32, align 4 +@C = external dso_local constant i32, align 4 +@D = external dso_local constant i32, align 4 + +define dso_local i32 @bar(i32 %x, i32 %y) { +entry: + %tobool = icmp ne i32 %x, 0 + br i1 %tobool, label %if.then, label %if.else + +if.then: + %call = call i32 @foo(i32 %x, i32* @A, i32* @C) + br label %return + +if.else: + %call1 = call i32 @foo(i32 %y, i32* @B, i32* @D) + br label %return + +return: + %retval.0 = phi i32 [ %call, %if.then ], [ %call1, %if.else ] + ret i32 %retval.0 +} + +define internal i32 @foo(i32 %x, i32* %b, i32* %c) { +entry: + %0 = load i32, i32* %b, align 4 + %add = add nsw i32 %x, %0 + %1 = load i32, i32* %c, align 4 + %add1 = add nsw i32 %add, %1 + ret i32 %add1 +} + +; CONST1-NOT: define internal i32 @foo.1(i32 %x, i32* %b, i32* %c) +; CONST1-NOT: define internal i32 @foo.2(i32 %x, i32* %b, i32* %c) + +; CHECK: define internal i32 @foo.1(i32 %x, i32* %b, i32* %c) { +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = load i32, i32* @B, align 4 +; CHECK-NEXT: %add = add nsw i32 %x, %0 +; CHECK-NEXT: %1 = load i32, i32* %c, align 4 +; CHECK-NEXT: %add1 = add nsw i32 %add, %1 +; CHECK-NEXT: ret i32 %add1 +; CHECK-NEXT: } + +; CHECK: define internal i32 @foo.2(i32 %x, i32* %b, i32* %c) { +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = load i32, i32* @A, align 4 +; CHECK-NEXT: %add = add nsw i32 %x, %0 +; CHECK-NEXT: %1 = load i32, i32* %c, align 4 +; CHECK-NEXT: %add1 = add nsw i32 %add, %1 +; CHECK-NEXT: ret i32 %add1 +; CHECK-NEXT: } diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization5.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization5.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization5.ll @@ -0,0 +1,40 @@ +; RUN: opt -function-specialization -force-function-specialization -S < %s | FileCheck %s + +; There's nothing to specialize here as both calls are the same, so check that: +; +; CHECK-NOT: define 
internal i32 @foo.1( +; CHECK-NOT: define internal i32 @foo.2( + +target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" + +@A = external dso_local constant i32, align 4 +@B = external dso_local constant i32, align 4 +@C = external dso_local constant i32, align 4 +@D = external dso_local constant i32, align 4 + +define dso_local i32 @bar(i32 %x, i32 %y) { +entry: + %tobool = icmp ne i32 %x, 0 + br i1 %tobool, label %if.then, label %if.else + +if.then: + %call = call i32 @foo(i32 %x, i32* @A, i32* @C) + br label %return + +if.else: + %call1 = call i32 @foo(i32 %y, i32* @A, i32* @C) + br label %return + +return: + %retval.0 = phi i32 [ %call, %if.then ], [ %call1, %if.else ] + ret i32 %retval.0 +} + +define internal i32 @foo(i32 %x, i32* %b, i32* %c) { +entry: + %0 = load i32, i32* %b, align 4 + %add = add nsw i32 %x, %0 + %1 = load i32, i32* %c, align 4 + %add1 = add nsw i32 %add, %1 + ret i32 %add1 +}