diff --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
--- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
+++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
@@ -49,6 +49,14 @@
   Constant *Actual; // A corresponding actual constant argument.
 
   ArgInfo(Argument *F, Constant *A) : Formal(F), Actual(A){};
+
+  bool operator==(const ArgInfo &Other) const {
+    return Formal == Other.Formal && Actual == Other.Actual;
+  }
+
+  friend hash_code hash_value(const ArgInfo &Info) {
+    return hash_value(std::make_pair(Info.Formal, Info.Actual));
+  }
 };
 
 class SCCPInstVisitor;
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -321,8 +321,7 @@
       }
 
       Changed = true;
-      for (auto &Entry : Specializations)
-        specializeFunction(F, Entry.second, WorkList);
+      createSpecializations(Specializations, F, WorkList);
     }
 
     updateSpecializedFuncs(Candidates, WorkList);
@@ -514,23 +513,75 @@
     return true;
   }
 
-  void specializeFunction(Function *F, SpecializationInfo &S,
-                          FuncList &WorkList) {
-    ValueToValueMapTy Mappings;
-    Function *Clone = cloneCandidateFunction(F, Mappings);
-
-    // Rewrite calls to the function so that they call the clone instead.
-    rewriteCallSites(Clone, S.Args, Mappings);
-
-    // Initialize the lattice state of the arguments of the function clone,
-    // marking the argument on which we specialized the function constant
-    // with the given value.
-    Solver.markArgInFuncSpecialization(Clone, S.Args);
-
-    // Mark all the specialized functions
-    WorkList.push_back(Clone);
-    NbFunctionsSpecialized++;
-
+  /// Create the clones described by \p Specializations and redirect each
+  /// recorded call site of \p F to the matching clone.
+  ///
+  /// Call sites whose specialization arguments are identical share a single
+  /// clone, so every clone is created, primed in the solver, counted and
+  /// queued on \p WorkList exactly once. Once all call sites have been
+  /// redirected, \p F is marked unreachable if its only remaining uses are
+  /// calls from within its own body.
+  void createSpecializations(SmallVectorImpl<CallSpecBinding> &Specializations,
+                             Function *F, FuncList &WorkList) {
+    // Clones created so far, bucketed by the hash of their specialization
+    // arguments. Distinct argument lists may share a hash, so each bucket
+    // keeps the full argument list and is searched for an exact match:
+    // reusing a clone on a bare hash match would propagate wrong constants.
+    using CloneEntry = std::pair<ArrayRef<ArgInfo>, Function *>;
+    SmallDenseMap<hash_code, SmallVector<CloneEntry, 1>> Clones;
+
+    for (const auto &Entry : Specializations) {
+      CallBase *Call = Entry.first;
+      const SpecializationInfo &Info = Entry.second;
+      ArrayRef<ArgInfo> Args = Info.Args;
+
+      // Reuse the clone of an identical specialization if there is one.
+      auto &Bucket = Clones[hash_value(Args)];
+      auto It = find_if(Bucket, [Args](const CloneEntry &Existing) {
+        return Existing.first == Args;
+      });
+      Function *Clone = It != Bucket.end() ? It->second : nullptr;
+      if (!Clone) {
+        ValueToValueMapTy Mappings;
+        Clone = cloneCandidateFunction(F, Mappings);
+        Bucket.push_back({Args, Clone});
+
+        // The clone's body still calls F. A recursive call that passes the
+        // clone's own formals in every specialized position passes the same
+        // constants the clone was specialized for, so it may call the clone
+        // directly. Only direct calls to F qualify (F may also appear as a
+        // plain call argument). The early-increment range is needed
+        // because setCalledFunction() moves the visited use off F's
+        // use-list.
+        for (User *U : make_early_inc_range(F->users()))
+          if (auto *CS = dyn_cast<CallBase>(U))
+            if (CS->getFunction() == Clone && CS->getCalledFunction() == F &&
+                all_of(Args, [CS, &Mappings](const ArgInfo &Arg) {
+                  unsigned ArgNo = Arg.Formal->getArgNo();
+                  return CS->getArgOperand(ArgNo) == Mappings[Arg.Formal];
+                }))
+              CS->setCalledFunction(Clone);
+
+        // Initialize the lattice state of the arguments of the function clone,
+        // marking the argument on which we specialized the function constant
+        // with the given value.
+        Solver.markArgInFuncSpecialization(Clone, Info.Args);
+
+        // Queue and count each new clone once, however many call sites end
+        // up sharing it.
+        WorkList.push_back(Clone);
+        NbFunctionsSpecialized++;
+      }
+
+      // Redirect the call site to the clone.
+      LLVM_DEBUG(dbgs() << "FnSpecialization: Redirecting " << *Call << " in "
+                        << Call->getFunction()->getName() << " to "
+                        << Clone->getName() << "\n");
+      Call->setCalledFunction(Clone);
+      Solver.markOverdefined(Call);
+    }
+
+    // All call sites have been redirected, so F's use-list is now final.
     // If the function has been completely specialized, the original function
     // is no longer needed. Mark it unreachable.
     if (F->getNumUses() == 0 || all_of(F->users(), [F](User *U) {
@@ -753,54 +804,6 @@
     }
   }
 
-  /// Rewrite calls to function \p F to call function \p Clone instead.
-  ///
-  /// This function modifies calls to function \p F as long as the actual
-  /// arguments match those in \p Args. Note that for recursive calls we
-  /// need to compare against the cloned formal arguments.
-  ///
-  /// Callsites that have been marked with the MinSize function attribute won't
-  /// be specialized and rewritten.
-  void rewriteCallSites(Function *Clone, const SmallVectorImpl<ArgInfo> &Args,
-                        ValueToValueMapTy &Mappings) {
-    assert(!Args.empty() && "Specialization without arguments");
-    Function *F = Args[0].Formal->getParent();
-
-    SmallVector<CallBase *, 4> CallSitesToRewrite;
-    for (auto *U : F->users()) {
-      if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
-        continue;
-      auto &CS = *cast<CallBase>(U);
-      if (!CS.getCalledFunction() || CS.getCalledFunction() != F)
-        continue;
-      CallSitesToRewrite.push_back(&CS);
-    }
-
-    LLVM_DEBUG(dbgs() << "FnSpecialization: Replacing call sites of "
-                      << F->getName() << " with " << Clone->getName() << "\n");
-
-    for (auto *CS : CallSitesToRewrite) {
-      LLVM_DEBUG(dbgs() << "FnSpecialization: "
-                        << CS->getFunction()->getName() << " ->" << *CS
-                        << "\n");
-      if (/* recursive call */
-          (CS->getFunction() == Clone &&
-           all_of(Args,
-                  [CS, &Mappings](const ArgInfo &Arg) {
-                    unsigned ArgNo = Arg.Formal->getArgNo();
-                    return CS->getArgOperand(ArgNo) == Mappings[Arg.Formal];
-                  })) ||
-          /* normal call */
-          all_of(Args, [CS](const ArgInfo &Arg) {
-            unsigned ArgNo = Arg.Formal->getArgNo();
-            return CS->getArgOperand(ArgNo) == Arg.Actual;
-          })) {
-        CS->setCalledFunction(Clone);
-        Solver.markOverdefined(CS);
-      }
-    }
-  }
-
   void updateSpecializedFuncs(FuncList &Candidates, FuncList &WorkList) {
     for (auto *F : WorkList) {
       SpecializedFuncs.insert(F);
diff --git a/llvm/test/Transforms/FunctionSpecialization/identical-specializations.ll b/llvm/test/Transforms/FunctionSpecialization/identical-specializations.ll
--- a/llvm/test/Transforms/FunctionSpecialization/identical-specializations.ll
+++ b/llvm/test/Transforms/FunctionSpecialization/identical-specializations.ll
@@ -69,8 +69,4 @@
 ; CHECK-NEXT:    [[CMP0:%.*]] = call i64 @minus(i64 [[X:%.*]], i64 [[Y:%.*]])
 ; CHECK-NEXT:    [[CMP1:%.*]] = call i64 @plus(i64 [[X]], i64 [[Y]])
 
-; CHECK-LABEL: @compute.3
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[CMP0:%.*]] = call i64 @plus(i64 [[X:%.*]], i64 [[Y:%.*]])
-; CHECK-NEXT:    [[CMP1:%.*]] = call i64 @minus(i64 [[X]], i64 [[Y]])
-
+; CHECK-NOT: define internal i64 @compute.3(