diff --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h --- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h +++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h @@ -171,14 +171,9 @@ /// Return a reference to the set of argument tracked functions. SmallPtrSetImpl &getArgumentTrackedFunctions(); - /// Mark the constant arguments of a new function specialization. \p F points - /// to the cloned function and \p Args contains a list of constant arguments - /// represented as pairs of {formal,actual} values (the formal argument is - /// associated with the original function definition). All other arguments of - /// the specialization inherit the lattice state of their corresponding values - /// in the original function. - void markArgInFuncSpecialization(Function *F, - const SmallVectorImpl &Args); + /// Computes the lattice value for each formal argument of the function + /// called by \p CB, based on the actual arguments passed at \p CB. + void handleCallArguments(CallBase &CB); /// Mark all of the blocks in function \p F non-executable. Clients can used /// this method to erase a function from the module (e.g., if it has been diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp --- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -343,11 +343,18 @@ Spec &S = AllSpecs[BestSpecs[I]]; S.Clone = createSpecialization(S.F, S.Sig); + // Visiting the call sites of an argument tracked function will mark + // its entry basic block as executable. Since there are no known + // call sites at this point, we need to do this manually. + if (S.CallSites.empty()) + Solver.markBlockExecutable(&S.Clone->front()); + // Update the known call sites to call the clone.
for (CallBase *Call : S.CallSites) { LLVM_DEBUG(dbgs() << "FnSpecialization: Redirecting " << *Call << " to call " << S.Clone->getName() << "\n"); Call->setCalledFunction(S.Clone); + Solver.handleCallArguments(*Call); } Clones.push_back(S.Clone); @@ -533,13 +540,7 @@ Function *FunctionSpecializer::createSpecialization(Function *F, const SpecSig &S) { Function *Clone = cloneCandidateFunction(F); - // Initialize the lattice state of the arguments of the function clone, - // marking the argument on which we specialized the function constant - // with the given value. - Solver.markArgInFuncSpecialization(Clone, S.Args); - Solver.addArgumentTrackedFunction(Clone); - Solver.markBlockExecutable(&Clone->front()); // Mark all the specialized functions Specializations.insert(Clone); @@ -768,6 +769,7 @@ LLVM_DEBUG(dbgs() << "FnSpecialization: Redirecting " << *CS << " to call " << BestSpec->Clone->getName() << "\n"); CS->setCalledFunction(BestSpec->Clone); + Solver.handleCallArguments(*CS); ShouldDecrementCount = true; } diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -551,7 +551,6 @@ } void handleCallOverdefined(CallBase &CB); void handleCallResult(CallBase &CB); - void handleCallArguments(CallBase &CB); void handleExtractOfWithOverflow(ExtractValueInst &EVI, const WithOverflowInst *WO, unsigned Idx); @@ -738,8 +737,7 @@ return TrackingIncomingArguments; } - void markArgInFuncSpecialization(Function *F, - const SmallVectorImpl &Args); + void handleCallArguments(CallBase &CB); void markFunctionUnreachable(Function *F) { for (auto &BB : *F) @@ -833,40 +831,6 @@ return nullptr; } -void SCCPInstVisitor::markArgInFuncSpecialization( - Function *F, const SmallVectorImpl &Args) { - assert(!Args.empty() && "Specialization without arguments"); - assert(F->arg_size() == Args[0].Formal->getParent()->arg_size() && - "Functions should have the
same number of arguments"); - - auto Iter = Args.begin(); - Argument *NewArg = F->arg_begin(); - Argument *OldArg = Args[0].Formal->getParent()->arg_begin(); - for (auto End = F->arg_end(); NewArg != End; ++NewArg, ++OldArg) { - - LLVM_DEBUG(dbgs() << "SCCP: Marking argument " - << NewArg->getNameOrAsOperand() << "\n"); - - if (Iter != Args.end() && OldArg == Iter->Formal) { - // Mark the argument constants in the new function. - markConstant(NewArg, Iter->Actual); - ++Iter; - } else if (ValueState.count(OldArg)) { - // For the remaining arguments in the new function, copy the lattice state - // over from the old function. - // - // Note: This previously looked like this: - // ValueState[NewArg] = ValueState[OldArg]; - // This is incorrect because the DenseMap class may resize the underlying - // memory when inserting `NewArg`, which will invalidate the reference to - // `OldArg`. Instead, we make sure `NewArg` exists before setting it. - auto &NewValue = ValueState[NewArg]; - NewValue = ValueState[OldArg]; - pushToWorkList(NewValue, NewArg); - } - } -} - void SCCPInstVisitor::visitInstruction(Instruction &I) { // All the instructions we don't do any special handling for just // go to overdefined.
@@ -1949,9 +1913,8 @@ return Visitor->getArgumentTrackedFunctions(); } -void SCCPSolver::markArgInFuncSpecialization( - Function *F, const SmallVectorImpl &Args) { - Visitor->markArgInFuncSpecialization(F, Args); +void SCCPSolver::handleCallArguments(CallBase &CB) { + Visitor->handleCallArguments(CB); } void SCCPSolver::markFunctionUnreachable(Function *F) { diff --git a/llvm/test/Transforms/FunctionSpecialization/identical-specializations.ll b/llvm/test/Transforms/FunctionSpecialization/identical-specializations.ll --- a/llvm/test/Transforms/FunctionSpecialization/identical-specializations.ll +++ b/llvm/test/Transforms/FunctionSpecialization/identical-specializations.ll @@ -37,7 +37,7 @@ entry: %op0 = call i64 %binop1(i64 %x, i64 %y) %op1 = call i64 %binop2(i64 %x, i64 %y) - %op2 = call i64 @compute(i64 %x, i64 %y, ptr %binop1, ptr @plus) + %op2 = call i64 @compute(i64 %x, i64 42, ptr %binop1, ptr @plus) %add0 = add i64 %op0, %op1 %add1 = add i64 %add0, %op2 %div = sdiv i64 %add1, %x @@ -62,18 +62,37 @@ ; CHECK-LABEL: @compute.1 ; CHECK-NEXT: entry: -; CHECK-NEXT: [[CMP0:%.*]] = call i64 %binop1(i64 [[X:%.*]], i64 [[Y:%.*]]) -; CHECK-NEXT: [[CMP1:%.*]] = call i64 @plus(i64 [[X]], i64 [[Y]]) -; CHECK-NEXT: [[CMP2:%.*]] = call i64 @compute.1(i64 [[X]], i64 [[Y]], ptr %binop1, ptr @plus) +; CHECK-NEXT: [[OP0:%.*]] = call i64 %binop1(i64 [[X:%.*]], i64 42) +; CHECK-NEXT: [[OP1:%.*]] = call i64 @plus(i64 [[X]], i64 42) +; CHECK-NEXT: [[OP2:%.*]] = call i64 @compute.1(i64 [[X]], i64 42, ptr %binop1, ptr @plus) +; CHECK-NEXT: [[ADD0:%.*]] = add i64 [[OP0]], [[OP1]] +; CHECK-NEXT: [[ADD1:%.*]] = add i64 [[ADD0]], [[OP2]] +; CHECK-NEXT: [[DIV:%.*]] = sdiv i64 [[ADD1]], [[X]] +; CHECK-NEXT: [[SUB:%.*]] = sub i64 [[DIV]], 42 +; CHECK-NEXT: [[MUL:%.*]] = mul i64 [[SUB]], 2 +; CHECK-NEXT: ret i64 [[MUL]] ; CHECK-LABEL: @compute.2 ; CHECK-NEXT: entry: -; CHECK-NEXT: [[CMP0:%.*]] = call i64 @plus(i64 [[X:%.*]], i64 [[Y:%.*]]) -; CHECK-NEXT: [[CMP1:%.*]] = call i64 @minus(i64
[[X]], i64 [[Y]]) -; CHECK-NEXT: [[CMP2:%.*]] = call i64 @compute.1(i64 [[X]], i64 [[Y]], ptr @plus, ptr @plus) +; CHECK-NEXT: [[OP0:%.*]] = call i64 @plus(i64 [[X:%.*]], i64 [[Y:%.*]]) +; CHECK-NEXT: [[OP1:%.*]] = call i64 @minus(i64 [[X]], i64 [[Y]]) +; CHECK-NEXT: [[OP2:%.*]] = call i64 @compute.1(i64 [[X]], i64 42, ptr @plus, ptr @plus) +; CHECK-NEXT: [[ADD0:%.*]] = add i64 [[OP0]], [[OP1]] +; CHECK-NEXT: [[ADD1:%.*]] = add i64 [[ADD0]], [[OP2]] +; CHECK-NEXT: [[DIV:%.*]] = sdiv i64 [[ADD1]], [[X]] +; CHECK-NEXT: [[SUB:%.*]] = sub i64 [[DIV]], [[Y]] +; CHECK-NEXT: [[MUL:%.*]] = mul i64 [[SUB]], 2 +; CHECK-NEXT: ret i64 [[MUL]] ; CHECK-LABEL: @compute.3 ; CHECK-NEXT: entry: -; CHECK-NEXT: [[CMP0:%.*]] = call i64 @minus(i64 [[X:%.*]], i64 [[Y:%.*]]) -; CHECK-NEXT: [[CMP1:%.*]] = call i64 @plus(i64 [[X]], i64 [[Y]]) -; CHECK-NEXT: [[CMP2:%.*]] = call i64 @compute.3(i64 [[X]], i64 [[Y]], ptr @minus, ptr @plus) +; CHECK-NEXT: [[OP0:%.*]] = call i64 @minus(i64 [[X:%.*]], i64 [[Y:%.*]]) +; CHECK-NEXT: [[OP1:%.*]] = call i64 @plus(i64 [[X]], i64 [[Y]]) +; CHECK-NEXT: [[OP2:%.*]] = call i64 @compute.3(i64 [[X]], i64 42, ptr @minus, ptr @plus) +; CHECK-NEXT: [[ADD0:%.*]] = add i64 [[OP0]], [[OP1]] +; CHECK-NEXT: [[ADD1:%.*]] = add i64 [[ADD0]], [[OP2]] +; CHECK-NEXT: [[DIV:%.*]] = sdiv i64 [[ADD1]], [[X]] +; CHECK-NEXT: [[SUB:%.*]] = sub i64 [[DIV]], [[Y]] +; CHECK-NEXT: [[MUL:%.*]] = mul i64 [[SUB]], 2 +; CHECK-NEXT: ret i64 [[MUL]] + diff --git a/llvm/test/Transforms/FunctionSpecialization/specialize-multiple-arguments.ll b/llvm/test/Transforms/FunctionSpecialization/specialize-multiple-arguments.ll --- a/llvm/test/Transforms/FunctionSpecialization/specialize-multiple-arguments.ll +++ b/llvm/test/Transforms/FunctionSpecialization/specialize-multiple-arguments.ll @@ -116,11 +116,11 @@ ; ; THREE-LABEL: define internal i64 @compute.3(i64 %x, i64 %y, ptr %binop1, ptr %binop2) { ; THREE-NEXT: entry: -; THREE-NEXT: [[TMP0:%.+]] = call i64 @minus(i64 %x, i64 %y) -;
THREE-NEXT: [[TMP1:%.+]] = call i64 @power(i64 %x, i64 %y) +; THREE-NEXT: [[TMP0:%.+]] = call i64 @minus(i64 %x, i64 42) +; THREE-NEXT: [[TMP1:%.+]] = call i64 @power(i64 %x, i64 42) ; THREE-NEXT: [[TMP2:%.+]] = add i64 [[TMP0]], [[TMP1]] ; THREE-NEXT: [[TMP3:%.+]] = sdiv i64 [[TMP2]], %x -; THREE-NEXT: [[TMP4:%.+]] = sub i64 [[TMP3]], %y +; THREE-NEXT: [[TMP4:%.+]] = sub i64 [[TMP3]], 42 ; THREE-NEXT: [[TMP5:%.+]] = mul i64 [[TMP4]], 2 ; THREE-NEXT: ret i64 [[TMP5]] ; THREE-NEXT: }