diff --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h --- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h +++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h @@ -142,6 +142,10 @@ void removeLatticeValueFor(Value *V); + /// Invalidate the Lattice Value of \p Call and its users after specializing + /// the call. Then recompute it. + void resetLatticeValueFor(CallBase *Call); + const ValueLatticeElement &getLatticeValueFor(Value *V) const; /// getTrackedRetVals - Get the inferred return value map. @@ -171,15 +175,6 @@ /// Return a reference to the set of argument tracked functions. SmallPtrSetImpl &getArgumentTrackedFunctions(); - /// Mark the constant arguments of a new function specialization. \p F points - /// to the cloned function and \p Args contains a list of constant arguments - /// represented as pairs of {formal,actual} values (the formal argument is - /// associated with the original function definition). All other arguments of - /// the specialization inherit the lattice state of their corresponding values - /// in the original function. - void markArgInFuncSpecialization(Function *F, - const SmallVectorImpl &Args); - /// Mark all of the blocks in function \p F non-executable. Clients can used /// this method to erase a function from the module (e.g., if it has been /// completely specialized and is no longer needed). diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp --- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -343,11 +343,18 @@ Spec &S = AllSpecs[BestSpecs[I]]; S.Clone = createSpecialization(S.F, S.Sig); + // Visiting the callsites of an argument tracked function will mark + // its entry basic block as executable. Since there are no known + // callsites at this point, we need to do this manually. 
+ if (S.CallSites.empty()) + Solver.markBlockExecutable(&S.Clone->front()); + // Update the known call sites to call the clone. for (CallBase *Call : S.CallSites) { LLVM_DEBUG(dbgs() << "FnSpecialization: Redirecting " << *Call << " to call " << S.Clone->getName() << "\n"); Call->setCalledFunction(S.Clone); + Solver.resetLatticeValueFor(Call); } Clones.push_back(S.Clone); @@ -364,6 +371,10 @@ updateCallSites(F, AllSpecs.begin() + Begin, AllSpecs.begin() + End); } + // Updating the callsites may have changed their lattice state. + // Run the solver to notify the users of the modified callsites. + Solver.solve(); + promoteConstantStackValues(); return true; } @@ -533,13 +544,8 @@ Function *FunctionSpecializer::createSpecialization(Function *F, const SpecSig &S) { Function *Clone = cloneCandidateFunction(F); - // Initialize the lattice state of the arguments of the function clone, - // marking the argument on which we specialized the function constant - // with the given value. - Solver.markArgInFuncSpecialization(Clone, S.Args); - Solver.addArgumentTrackedFunction(Clone); - Solver.markBlockExecutable(&Clone->front()); + Solver.addTrackedFunction(Clone); // Mark all the specialized functions Specializations.insert(Clone); @@ -768,6 +774,7 @@ LLVM_DEBUG(dbgs() << "FnSpecialization: Redirecting " << *CS << " to call " << BestSpec->Clone->getName() << "\n"); CS->setCalledFunction(BestSpec->Clone); + Solver.visit(CS); ShouldDecrementCount = true; } diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -373,6 +373,10 @@ MapVector, ValueLatticeElement> TrackedMultipleRetVals; + /// The set of values whose lattice has been invalidated. + /// Populated by resetLatticeValueFor(), cleared each time we solve. 
+ DenseSet Invalidated; + /// MRVFunctionsTracked - Each function in TrackedMultipleRetVals is /// represented here for efficient lookup. SmallPtrSet MRVFunctionsTracked; @@ -498,6 +502,59 @@ return LV; } + /// Traverse the def-use chain of \p V, marking itself and its users as + /// "unknown" on the way. + void invalidate(Value *V) { + auto [_, Inserted] = Invalidated.insert(V); + if (!Inserted) + return; + + bool Found = false; + // For return instructions we need to invalidate the tracked returns map. + // Anything else has its lattice in the value map or the struct value map. + if (auto *RetInst = dyn_cast(V)) { + Function *F = RetInst->getParent()->getParent(); + auto TFRVI = TrackedRetVals.find(F); + if (TFRVI != TrackedRetVals.end()) { + TFRVI->second = ValueLatticeElement(); + Found = true; + } else if (MRVFunctionsTracked.count(F)) { + auto *STy = cast(F->getReturnType()); + for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) + TrackedMultipleRetVals[std::make_pair(F, I)] = ValueLatticeElement(); + Found = true; + } + // Next we need to invalidate the users of the function. + V = F; + } else if (auto *STy = dyn_cast(V->getType())) { + for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) { + auto It = StructValueState.find(std::make_pair(V, I)); + if (It != StructValueState.end()) { + It->second = ValueLatticeElement(); + Found = true; + } + } + } else { + auto It = ValueState.find(V); + if (It != ValueState.end()) { + It->second = ValueLatticeElement(); + Found = true; + } + } + + if (Found) { + LLVM_DEBUG(dbgs() << "Invalidated lattice for " << *V << "\n"); + + for (User *U : V->users()) + invalidate(U); + + auto It = AdditionalUsers.find(V); + if (It != AdditionalUsers.end()) + for (User *U : It->second) + invalidate(U); + } + } + /// markEdgeExecutable - Mark a basic block as executable, adding it to the BB /// work list if it is not already executable. 
bool markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest); @@ -700,6 +757,19 @@ void removeLatticeValueFor(Value *V) { ValueState.erase(V); } + /// Invalidate the Lattice Value of \p Call and its users after specializing + /// the call. Then recompute it. + void resetLatticeValueFor(CallBase *Call) { + // Calls to void returning functions do not need invalidation. + Function *F = Call->getCalledFunction(); + if (!F->getReturnType()->isVoidTy()) { + assert((TrackedRetVals.count(F) || MRVFunctionsTracked.count(F)) && + "All non void specializations should be tracked"); + invalidate(Call); + } + visitCallBase(*Call); + } + const ValueLatticeElement &getLatticeValueFor(Value *V) const { assert(!V->getType()->isStructTy() && "Should use getStructLatticeValueFor"); @@ -738,9 +808,6 @@ return TrackingIncomingArguments; } - void markArgInFuncSpecialization(Function *F, - const SmallVectorImpl &Args); - void markFunctionUnreachable(Function *F) { for (auto &BB : *F) BBExecutable.erase(&BB); @@ -833,40 +900,6 @@ return nullptr; } -void SCCPInstVisitor::markArgInFuncSpecialization( - Function *F, const SmallVectorImpl &Args) { - assert(!Args.empty() && "Specialization without arguments"); - assert(F->arg_size() == Args[0].Formal->getParent()->arg_size() && - "Functions should have the same number of arguments"); - - auto Iter = Args.begin(); - Argument *NewArg = F->arg_begin(); - Argument *OldArg = Args[0].Formal->getParent()->arg_begin(); - for (auto End = F->arg_end(); NewArg != End; ++NewArg, ++OldArg) { - - LLVM_DEBUG(dbgs() << "SCCP: Marking argument " - << NewArg->getNameOrAsOperand() << "\n"); - - if (Iter != Args.end() && OldArg == Iter->Formal) { - // Mark the argument constants in the new function. - markConstant(NewArg, Iter->Actual); - ++Iter; - } else if (ValueState.count(OldArg)) { - // For the remaining arguments in the new function, copy the lattice state - // over from the old function. 
- // - // Note: This previously looked like this: - // ValueState[NewArg] = ValueState[OldArg]; - // This is incorrect because the DenseMap class may resize the underlying - // memory when inserting `NewArg`, which will invalidate the reference to - // `OldArg`. Instead, we make sure `NewArg` exists before setting it. - auto &NewValue = ValueState[NewArg]; - NewValue = ValueState[OldArg]; - pushToWorkList(NewValue, NewArg); - } - } -} - void SCCPInstVisitor::visitInstruction(Instruction &I) { // All the instructions we don't do any special handling for just // go to overdefined. @@ -1700,6 +1733,8 @@ } void SCCPInstVisitor::solve() { + // Empty the set of values whose lattice has been invalidated. + Invalidated.clear(); // Process the work lists until they are empty! while (!BBWorkList.empty() || !InstWorkList.empty() || !OverdefinedInstWorkList.empty()) { @@ -1917,6 +1952,10 @@ return Visitor->removeLatticeValueFor(V); } +void SCCPSolver::resetLatticeValueFor(CallBase *Call) { + Visitor->resetLatticeValueFor(Call); +} + const ValueLatticeElement &SCCPSolver::getLatticeValueFor(Value *V) const { return Visitor->getLatticeValueFor(V); } @@ -1949,11 +1988,6 @@ return Visitor->getArgumentTrackedFunctions(); } -void SCCPSolver::markArgInFuncSpecialization( - Function *F, const SmallVectorImpl &Args) { - Visitor->markArgInFuncSpecialization(F, Args); -} - void SCCPSolver::markFunctionUnreachable(Function *F) { Visitor->markFunctionUnreachable(F); } diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll --- a/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll +++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll @@ -36,7 +36,7 @@ ; CHECK-NEXT: [[TMP1:%.*]] = call i64 @func2.1(ptr getelementptr inbounds ([[STRUCT]], ptr @Global, i32 0, i32 4)) ; 
CHECK-NEXT: br label [[MERGE]] ; CHECK: merge: -; CHECK-NEXT: [[TMP2:%.*]] = phi i64 [ [[TMP0]], [[PLUS]] ], [ [[TMP1]], [[MINUS]] ] +; CHECK-NEXT: [[TMP2:%.*]] = phi i64 [ ptrtoint (ptr getelementptr inbounds ([[STRUCT:%.*]], ptr @Global, i32 0, i32 3) to i64), [[PLUS]] ], [ ptrtoint (ptr getelementptr inbounds ([[STRUCT:%.*]], ptr @Global, i32 0, i32 4) to i64), [[MINUS]] ] ; CHECK-NEXT: ret i64 [[TMP2]] ; entry: @@ -70,3 +70,4 @@ %3 = add i64 %1, %2 ret i64 %3 } + diff --git a/llvm/test/Transforms/FunctionSpecialization/specialize-multiple-arguments.ll b/llvm/test/Transforms/FunctionSpecialization/specialize-multiple-arguments.ll --- a/llvm/test/Transforms/FunctionSpecialization/specialize-multiple-arguments.ll +++ b/llvm/test/Transforms/FunctionSpecialization/specialize-multiple-arguments.ll @@ -116,11 +116,11 @@ ; ; THREE-LABEL: define internal i64 @compute.3(i64 %x, i64 %y, ptr %binop1, ptr %binop2) { ; THREE-NEXT: entry: -; THREE-NEXT: [[TMP0:%.+]] = call i64 @minus(i64 %x, i64 %y) -; THREE-NEXT: [[TMP1:%.+]] = call i64 @power(i64 %x, i64 %y) +; THREE-NEXT: [[TMP0:%.+]] = call i64 @minus(i64 %x, i64 42) +; THREE-NEXT: [[TMP1:%.+]] = call i64 @power(i64 %x, i64 42) ; THREE-NEXT: [[TMP2:%.+]] = add i64 [[TMP0]], [[TMP1]] ; THREE-NEXT: [[TMP3:%.+]] = sdiv i64 [[TMP2]], %x -; THREE-NEXT: [[TMP4:%.+]] = sub i64 [[TMP3]], %y +; THREE-NEXT: [[TMP4:%.+]] = sub i64 [[TMP3]], 42 ; THREE-NEXT: [[TMP5:%.+]] = mul i64 [[TMP4]], 2 ; THREE-NEXT: ret i64 [[TMP5]] ; THREE-NEXT: } diff --git a/llvm/test/Transforms/FunctionSpecialization/track-return.ll b/llvm/test/Transforms/FunctionSpecialization/track-return.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/FunctionSpecialization/track-return.ll @@ -0,0 +1,106 @@ +; RUN: opt -passes="ipsccp" -force-specialization -funcspec-for-literal-constant -funcspec-max-iters=3 -S < %s | FileCheck %s + +define i64 @main() { +; CHECK: define i64 @main +; CHECK-NEXT: entry: +; CHECK-NEXT: [[C1:%.*]] = call i64 @foo.1(i1 
true, i64 3, i64 1) +; CHECK-NEXT: [[C2:%.*]] = call i64 @foo.2(i1 false, i64 4, i64 -1) +; CHECK-NEXT: ret i64 8 +; +entry: + %c1 = call i64 @foo(i1 true, i64 3, i64 1) + %c2 = call i64 @foo(i1 false, i64 4, i64 -1) + %add = add i64 %c1, %c2 + ret i64 %add +} + +define internal i64 @foo(i1 %flag, i64 %m, i64 %n) { +; +; CHECK: define internal i64 @foo.1 +; CHECK-NEXT: entry: +; CHECK-NEXT: br label %plus +; CHECK: plus: +; CHECK-NEXT: [[N0:%.*]] = call i64 @binop.4(i64 3, i64 1) +; CHECK-NEXT: [[RES0:%.*]] = call i64 @bar.6(i64 4) +; CHECK-NEXT: br label %merge +; CHECK: merge: +; CHECK-NEXT: ret i64 undef +; +; CHECK: define internal i64 @foo.2 +; CHECK-NEXT: entry: +; CHECK-NEXT: br label %minus +; CHECK: minus: +; CHECK-NEXT: [[N1:%.*]] = call i64 @binop.3(i64 4, i64 -1) +; CHECK-NEXT: [[RES1:%.*]] = call i64 @bar.5(i64 3) +; CHECK-NEXT: br label %merge +; CHECK: merge: +; CHECK-NEXT: ret i64 undef +; +entry: + br i1 %flag, label %plus, label %minus + +plus: + %n0 = call i64 @binop(i64 %m, i64 %n) + %res0 = call i64 @bar(i64 %n0) + br label %merge + +minus: + %n1 = call i64 @binop(i64 %m, i64 %n) + %res1 = call i64 @bar(i64 %n1) + br label %merge + +merge: + %res = phi i64 [ %res0, %plus ], [ %res1, %minus] + ret i64 %res +} + +define internal i64 @binop(i64 %x, i64 %y) { +; +; CHECK: define internal i64 @binop.3 +; CHECK-NEXT: entry: +; CHECK-NEXT: ret i64 undef +; +; CHECK: define internal i64 @binop.4 +; CHECK-NEXT: entry: +; CHECK-NEXT: ret i64 undef +; +entry: + %z = add i64 %x, %y + ret i64 %z +} + +define internal i64 @bar(i64 %n) { +; +; CHECK: define internal i64 @bar.5 +; CHECK-NEXT: entry: +; CHECK-NEXT: br label %if.else +; CHECK: if.else: +; CHECK-NEXT: br label %if.end +; CHECK: if.end: +; CHECK-NEXT: ret i64 undef +; +; CHECK: define internal i64 @bar.6 +; CHECK-NEXT: entry: +; CHECK-NEXT: br label %if.then +; CHECK: if.then: +; CHECK-NEXT: br label %if.end +; CHECK: if.end: +; CHECK-NEXT: ret i64 undef +; +entry: + %cmp = icmp sgt i64 %n, 3 + br 
i1 %cmp, label %if.then, label %if.else + +if.then: + %res0 = sdiv i64 %n, 2 + br label %if.end + +if.else: + %res1 = mul i64 %n, 2 + br label %if.end + +if.end: + %res = phi i64 [ %res0, %if.then ], [ %res1, %if.else] + ret i64 %res +} +