// Review summary (this comment sits ahead of the first `diff --git` header;
// the patch itself starts on the next line):
//  * SCCPSolver.h gains `ArgInfo`, a {Formal, Actual} pair shared between
//    Function Specialization and the SCCP solver.
//  * FunctionSpecialization.cpp replaces its local `ArgInfo` with
//    `SpecializationInfo` {Arg, Gain}; the `Fn` and `Partial` members are
//    removed and the function is passed to specializeFunction() explicitly.
//  * cloneCandidateFunction() now fills a caller-provided ValueToValueMapTy,
//    which rewriteCallSites() uses to look up the cloned formal argument when
//    matching recursive calls.
//  * markArgInFuncSpecialization() now takes (Function *F, ArgInfo &Arg). The
//    call site passes the clone as F, and the implementation walks the old
//    arguments starting from Arg.Formal->getParent()->arg_begin().
// NOTE(review): the doc comment kept as context above
// markArgInFuncSpecialization in SCCPSolver.h still calls \p F the "original
// function"; with the new call site F is the clone -- confirm and update it.
// NOTE(review): several template argument lists look like they are missing
// (e.g. `SmallVectorImpl;`, `dyn_cast(U)`, `isa(U)`), presumably lost when
// this diff was extracted -- verify against the original patch before use.
diff --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h --- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h +++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h @@ -34,6 +34,14 @@ PostDominatorTree *PDT; }; +/// Helper struct shared between Function Specialization and SCCP Solver. +struct ArgInfo { + Argument *Formal; // The Formal argument being analysed. + Constant *Actual; // A corresponding actual constant argument. + + ArgInfo(Argument *F, Constant *A) : Formal(F), Actual(A) {}; +}; + class SCCPInstVisitor; //===----------------------------------------------------------------------===// @@ -138,7 +146,7 @@ /// specialization. The argument's parent function is a specialization of the /// original function \p F. All other arguments of the specialization inherit /// the lattice state of their corresponding values in the original function. - void markArgInFuncSpecialization(Function *F, Argument *A, Constant *C); + void markArgInFuncSpecialization(Function *F, ArgInfo &Arg); /// Mark all of the blocks in function \p F non-executable. Clients can used /// this method to erase a function from the module (e.g., if it has been diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp --- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -108,24 +108,18 @@ namespace { // Bookkeeping struct to pass data from the analysis and profitability phase // to the actual transform helper functions. -struct ArgInfo { - Function *Fn; // The function to perform specialisation on. - Argument *Formal; // The Formal argument being analysed. - Constant *Actual; // A corresponding actual constant argument. +struct SpecializationInfo { + ArgInfo Arg; // Stores the {formal,actual} argument pair. InstructionCost Gain; // Profitability: Gain = Bonus - Cost. 
- // Flag if this will be a partial specialization, in which case we will need - // to keep the original function around in addition to the added - // specializations. - bool Partial = false; - - ArgInfo(Function *F, Argument *A, Constant *C, InstructionCost G) - : Fn(F), Formal(A), Actual(C), Gain(G){}; + SpecializationInfo(Argument *A, Constant *C, InstructionCost G) + : Arg(A, C), Gain(G) {}; }; } // Anonymous namespace using FuncList = SmallVectorImpl; -using ConstList = SmallVectorImpl; +using ConstList = SmallVector; +using SpecializationList = SmallVector; // Helper to check if \p LV is either a constant or a constant // range with a single element. This should cover exactly the same cases as the @@ -312,14 +306,14 @@ LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for " << F->getName() << " is " << Cost << "\n"); - auto ConstArgs = calculateGains(F, Cost); - if (ConstArgs.empty()) { + SpecializationList Specializations = calculateGains(F, Cost); + if (Specializations.empty()) { LLVM_DEBUG(dbgs() << "FnSpecialization: no possible constants found\n"); continue; } - for (auto &CA : ConstArgs) { - specializeFunction(CA, WorkList); + for (SpecializationInfo &S : Specializations) { + specializeFunction(F, S, WorkList); Changed = true; } } @@ -390,9 +384,8 @@ /// Clone the function \p F and remove the ssa_copy intrinsics added by /// the SCCPSolver in the cloned version. - Function *cloneCandidateFunction(Function *F) { - ValueToValueMapTy EmptyMap; - Function *Clone = CloneFunction(F, EmptyMap); + Function *cloneCandidateFunction(Function *F, ValueToValueMapTy &Mappings) { + Function *Clone = CloneFunction(F, Mappings); removeSSACopy(*Clone); return Clone; } @@ -403,8 +396,8 @@ /// profitable to specialize. Specialization is performed on the first /// interesting argument. Specializations based on additional arguments will /// be evaluated on following iterations of the main IPSCCP solve loop. 
- SmallVector calculateGains(Function *F, InstructionCost Cost) { - SmallVector Worklist; + SpecializationList calculateGains(Function *F, InstructionCost Cost) { + SpecializationList WorkList; // Determine if we should specialize the function based on the values the // argument can take on. If specialization is not profitable, we continue // on to the next argument. @@ -416,7 +409,7 @@ // be set to false by isArgumentInteresting (that function only adds // values to the Constants list that are deemed profitable). bool IsPartial = true; - SmallVector ActualArgs; + ConstList ActualArgs; if (!isArgumentInteresting(&FormalArg, ActualArgs, IsPartial)) { LLVM_DEBUG(dbgs() << "FnSpecialization: Argument " << FormalArg.getNameOrAsOperand() @@ -432,47 +425,44 @@ if (Gain <= 0) continue; - Worklist.push_back({F, &FormalArg, ActualArg, Gain}); + WorkList.push_back({&FormalArg, ActualArg, Gain}); } - if (Worklist.empty()) + if (WorkList.empty()) continue; // Sort the candidates in descending order. - llvm::stable_sort(Worklist, [](const ArgInfo &L, const ArgInfo &R) { + llvm::stable_sort(WorkList, [](const SpecializationInfo &L, + const SpecializationInfo &R) { return L.Gain > R.Gain; }); // Truncate the worklist to 'MaxClonesThreshold' candidates if // necessary. 
- if (Worklist.size() > MaxClonesThreshold) { + if (WorkList.size() > MaxClonesThreshold) { LLVM_DEBUG(dbgs() << "FnSpecialization: Number of candidates exceed " << "the maximum number of clones threshold.\n" << "FnSpecialization: Truncating worklist to " << MaxClonesThreshold << " candidates.\n"); - Worklist.erase(Worklist.begin() + MaxClonesThreshold, - Worklist.end()); + WorkList.erase(WorkList.begin() + MaxClonesThreshold, + WorkList.end()); } - if (IsPartial || Worklist.size() < ActualArgs.size()) - for (auto &ActualArg : Worklist) - ActualArg.Partial = true; - LLVM_DEBUG( dbgs() << "FnSpecialization: Specializations for function " << F->getName() << "\n"; - for (auto &C : Worklist) { + for (SpecializationInfo &S : WorkList) { dbgs() << "FnSpecialization: FormalArg = " - << C.Formal->getNameOrAsOperand() << ", ActualArg = " - << C.Actual->getNameOrAsOperand() << ", Gain = " - << C.Gain << "\n"; + << S.Arg.Formal->getNameOrAsOperand() << ", ActualArg = " + << S.Arg.Actual->getNameOrAsOperand() << ", Gain = " + << S.Gain << "\n"; } ); // FIXME: Only one argument per function. break; } - return Worklist; + return WorkList; } bool isCandidateFunction(Function *F) { @@ -499,17 +489,18 @@ return true; } - void specializeFunction(ArgInfo &AI, FuncList &WorkList) { - Function *Clone = cloneCandidateFunction(AI.Fn); - Argument *ClonedArg = Clone->getArg(AI.Formal->getArgNo()); + void specializeFunction(Function *F, SpecializationInfo &S, + FuncList &WorkList) { + ValueToValueMapTy Mappings; + Function *Clone = cloneCandidateFunction(F, Mappings); // Rewrite calls to the function so that they call the clone instead. - rewriteCallSites(AI.Fn, Clone, *ClonedArg, AI.Actual); + rewriteCallSites(Clone, S.Arg, Mappings); // Initialize the lattice state of the arguments of the function clone, // marking the argument on which we specialized the function constant // with the given value. 
- Solver.markArgInFuncSpecialization(AI.Fn, ClonedArg, AI.Actual); + Solver.markArgInFuncSpecialization(Clone, S.Arg); // Mark all the specialized functions WorkList.push_back(Clone); @@ -517,14 +508,14 @@ // If the function has been completely specialized, the original function // is no longer needed. Mark it unreachable. - if (AI.Fn->getNumUses() == 0 || - all_of(AI.Fn->users(), [&AI](User *U) { + if (F->getNumUses() == 0 || + all_of(F->users(), [F](User *U) { if (auto *CS = dyn_cast(U)) - return CS->getFunction() == AI.Fn; + return CS->getFunction() == F; return false; })) { - Solver.markFunctionUnreachable(AI.Fn); - FullySpecialized.insert(AI.Fn); + Solver.markFunctionUnreachable(F); + FullySpecialized.insert(F); } } @@ -768,15 +759,16 @@ /// Rewrite calls to function \p F to call function \p Clone instead. /// - /// This function modifies calls to function \p F whose argument at index \p - /// ArgNo is equal to constant \p C. The calls are rewritten to call function - /// \p Clone instead. + /// This function modifies calls to the parent function of the formal argument + /// in \p Arg as long as the actual argument matches that in \p Arg. Note that + /// for recursive calls we need to compare against the cloned formal argument. /// /// Callsites that have been marked with the MinSize function attribute won't /// be specialized and rewritten. 
- void rewriteCallSites(Function *F, Function *Clone, Argument &Arg, - Constant *C) { - unsigned ArgNo = Arg.getArgNo(); + void rewriteCallSites(Function *Clone, ArgInfo &Arg, + ValueToValueMapTy &Mappings) { + Function *F = Arg.Formal->getParent(); + unsigned ArgNo = Arg.Formal->getArgNo(); SmallVector CallSitesToRewrite; for (auto *U : F->users()) { if (!isa(U) && !isa(U)) @@ -795,8 +787,11 @@ LLVM_DEBUG(dbgs() << "FnSpecialization: " << CS->getFunction()->getName() << " ->" << *CS << "\n"); - if ((CS->getFunction() == Clone && CS->getArgOperand(ArgNo) == &Arg) || - CS->getArgOperand(ArgNo) == C) { + if (/* recursive call */ + (CS->getFunction() == Clone && + CS->getArgOperand(ArgNo) == Mappings[Arg.Formal]) || + /* normal call */ + CS->getArgOperand(ArgNo) == Arg.Actual) { CS->setCalledFunction(Clone); Solver.markOverdefined(CS); } diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -452,7 +452,7 @@ return TrackingIncomingArguments; } - void markArgInFuncSpecialization(Function *F, Argument *A, Constant *C); + void markArgInFuncSpecialization(Function *F, ArgInfo &Arg); void markFunctionUnreachable(Function *F) { for (auto &BB : *F) @@ -526,24 +526,24 @@ return nullptr; } -void SCCPInstVisitor::markArgInFuncSpecialization(Function *F, Argument *A, - Constant *C) { - assert(F->arg_size() == A->getParent()->arg_size() && +void SCCPInstVisitor::markArgInFuncSpecialization(Function *F, ArgInfo &Arg) { + assert(F->arg_size() == Arg.Formal->getParent()->arg_size() && "Functions should have the same number of arguments"); - // Mark the argument constant in the new function. - markConstant(A, C); - - // For the remaining arguments in the new function, copy the lattice state - // over from the old function. 
- for (Argument *OldArg = F->arg_begin(), *NewArg = A->getParent()->arg_begin(), - *End = F->arg_end(); - OldArg != End; ++OldArg, ++NewArg) { + Argument *NewArg = F->arg_begin(); + Argument *OldArg = Arg.Formal->getParent()->arg_begin(); + for (auto End = F->arg_end(); NewArg != End; ++NewArg, ++OldArg) { LLVM_DEBUG(dbgs() << "SCCP: Marking argument " << NewArg->getNameOrAsOperand() << "\n"); - if (NewArg != A && ValueState.count(OldArg)) { + if (OldArg == Arg.Formal) { + // Mark the specialized argument constant in the new function. + markConstant(NewArg, Arg.Actual); + } else if (ValueState.count(OldArg)) { + // For the remaining arguments in the new function, copy the lattice state + // over from the old function. + // // Note: This previously looked like this: // ValueState[NewArg] = ValueState[OldArg]; // This is incorrect because the DenseMap class may resize the underlying @@ -1718,9 +1718,8 @@ return Visitor->getArgumentTrackedFunctions(); } -void SCCPSolver::markArgInFuncSpecialization(Function *F, Argument *A, - Constant *C) { - Visitor->markArgInFuncSpecialization(F, A, C); +void SCCPSolver::markArgInFuncSpecialization(Function *F, ArgInfo &Arg) { + Visitor->markArgInFuncSpecialization(F, Arg); } void SCCPSolver::markFunctionUnreachable(Function *F) {