diff --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h --- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h +++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h @@ -142,6 +142,10 @@ void removeLatticeValueFor(Value *V); + /// Invalidate the Lattice Value of \p Call and its users after specializing + /// the call. Then recompute it. + void resetLatticeValueFor(CallBase *Call); + const ValueLatticeElement &getLatticeValueFor(Value *V) const; /// getTrackedRetVals - Get the inferred return value map. diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp --- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -371,6 +371,33 @@ updateCallSites(F, AllSpecs.begin() + Begin, AllSpecs.begin() + End); } + for (Function *F : Clones) { + if (F->getReturnType()->isVoidTy()) + continue; + if (F->getReturnType()->isStructTy()) { + StructType *STy = cast<StructType>(F->getReturnType()); + if (!Solver.isStructLatticeConstant(F, STy)) + continue; + } else { + auto It = Solver.getTrackedRetVals().find(F); + assert(It != Solver.getTrackedRetVals().end() && + "Return value ought to be tracked"); + if (SCCPSolver::isOverdefined(It->second)) + continue; + } + for (User *U : F->users()) { + if (auto *CS = dyn_cast<CallBase>(U)) { + // The user instruction does not call our function. + if (CS->getCalledFunction() != F) + continue; + Solver.resetLatticeValueFor(CS); + } + } + } + + // Rerun the solver to notify the users of the modified callsites.
+ Solver.solve(); + promoteConstantStackValues(); return true; } @@ -541,6 +568,7 @@ Function *Clone = cloneCandidateFunction(F); Solver.addArgumentTrackedFunction(Clone); + Solver.addTrackedFunction(Clone); // Mark all the specialized functions Specializations.insert(Clone); diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -373,6 +373,10 @@ MapVector, ValueLatticeElement> TrackedMultipleRetVals; + /// The set of values whose lattice has been invalidated. + /// Populated by resetLatticeValueFor(), cleared before solving. + DenseSet<Value *> Invalidated; + /// MRVFunctionsTracked - Each function in TrackedMultipleRetVals is /// represented here for efficient lookup. SmallPtrSet MRVFunctionsTracked; @@ -498,6 +502,62 @@ return LV; } + /// Traverse the def-use chain of \p Inst, marking itself and its users as + /// "unknown" on the way. + void invalidate(Instruction *Inst) { + auto [_, Inserted] = Invalidated.insert(Inst); + if (!Inserted) + return; + + if (!BBExecutable.count(Inst->getParent())) + return; + + Value *V = nullptr; + // For return instructions we need to invalidate the tracked returns map. + // Anything else has its lattice in the value map.
+ if (auto *RetInst = dyn_cast<ReturnInst>(Inst)) { + Function *F = RetInst->getParent()->getParent(); + auto TFRVI = TrackedRetVals.find(F); + if (TFRVI != TrackedRetVals.end()) { + TFRVI->second = ValueLatticeElement(); + V = F; + } else if (MRVFunctionsTracked.count(F)) { + auto *STy = cast<StructType>(F->getReturnType()); + for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) + TrackedMultipleRetVals[std::make_pair(F, I)] = ValueLatticeElement(); + V = F; + } + } else if (auto *STy = dyn_cast<StructType>(Inst->getType())) { + for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) { + auto It = StructValueState.find(std::make_pair(Inst, I)); + if (It != StructValueState.end()) { + It->second = ValueLatticeElement(); + V = Inst; + } + } + } else { + auto It = ValueState.find(Inst); + if (It != ValueState.end()) { + It->second = ValueLatticeElement(); + V = Inst; + } + } + + if (V) { + LLVM_DEBUG(dbgs() << "Invalidated lattice for " << *V << "\n"); + + for (User *U : V->users()) + if (auto *UI = dyn_cast<Instruction>(U)) + invalidate(UI); + + auto It = AdditionalUsers.find(V); + if (It != AdditionalUsers.end()) + for (User *U : It->second) + if (auto *UI = dyn_cast<Instruction>(U)) + invalidate(UI); + } + } + /// markEdgeExecutable - Mark a basic block as executable, adding it to the BB /// work list if it is not already executable. bool markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest); @@ -699,6 +759,18 @@ void removeLatticeValueFor(Value *V) { ValueState.erase(V); } + /// Invalidate the Lattice Value of \p Call and its users after specializing + /// the call. Then recompute it. + void resetLatticeValueFor(CallBase *Call) { + // Calls to void returning functions do not need invalidation.
+ [[maybe_unused]] Function *F = Call->getCalledFunction(); + assert(!F->getReturnType()->isVoidTy() && + (TrackedRetVals.count(F) || MRVFunctionsTracked.count(F)) && + "All non void specializations should be tracked"); + invalidate(Call); + handleCallResult(*Call); + } + const ValueLatticeElement &getLatticeValueFor(Value *V) const { assert(!V->getType()->isStructTy() && "Should use getStructLatticeValueFor"); @@ -1664,6 +1736,9 @@ } void SCCPInstVisitor::solve() { + // Empty the set of values whose lattice has been invalidated. + Invalidated.clear(); + // Process the work lists until they are empty! while (!BBWorkList.empty() || !InstWorkList.empty() || !OverdefinedInstWorkList.empty()) { @@ -1881,6 +1956,10 @@ return Visitor->removeLatticeValueFor(V); } +void SCCPSolver::resetLatticeValueFor(CallBase *Call) { + Visitor->resetLatticeValueFor(Call); +} + const ValueLatticeElement &SCCPSolver::getLatticeValueFor(Value *V) const { return Visitor->getLatticeValueFor(V); } diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll --- a/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll +++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll @@ -36,7 +36,7 @@ ; CHECK-NEXT: [[TMP1:%.*]] = call i64 @func2.1(ptr getelementptr inbounds ([[STRUCT]], ptr @Global, i32 0, i32 4)) ; CHECK-NEXT: br label [[MERGE]] ; CHECK: merge: -; CHECK-NEXT: [[TMP2:%.*]] = phi i64 [ [[TMP0]], [[PLUS]] ], [ [[TMP1]], [[MINUS]] ] +; CHECK-NEXT: [[TMP2:%.*]] = phi i64 [ ptrtoint (ptr getelementptr inbounds ([[STRUCT:%.*]], ptr @Global, i32 0, i32 3) to i64), [[PLUS]] ], [ ptrtoint (ptr getelementptr inbounds ([[STRUCT:%.*]], ptr @Global, i32 0, i32 4) to i64), [[MINUS]] ] ; CHECK-NEXT: ret i64 [[TMP2]] ; entry: @@ -70,3 +70,4 @@ %3 = add i64 %1, %2 ret i64 %3 } + diff --git
a/llvm/test/Transforms/FunctionSpecialization/track-return.ll b/llvm/test/Transforms/FunctionSpecialization/track-return.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/FunctionSpecialization/track-return.ll @@ -0,0 +1,106 @@ +; RUN: opt -passes="ipsccp" -force-specialization -funcspec-for-literal-constant -funcspec-max-iters=3 -S < %s | FileCheck %s + +define i64 @main() { +; CHECK: define i64 @main +; CHECK-NEXT: entry: +; CHECK-NEXT: [[C1:%.*]] = call i64 @foo.1(i1 true, i64 3, i64 1) +; CHECK-NEXT: [[C2:%.*]] = call i64 @foo.2(i1 false, i64 4, i64 -1) +; CHECK-NEXT: ret i64 8 +; +entry: + %c1 = call i64 @foo(i1 true, i64 3, i64 1) + %c2 = call i64 @foo(i1 false, i64 4, i64 -1) + %add = add i64 %c1, %c2 + ret i64 %add +} + +define internal i64 @foo(i1 %flag, i64 %m, i64 %n) { +; +; CHECK: define internal i64 @foo.1 +; CHECK-NEXT: entry: +; CHECK-NEXT: br label %plus +; CHECK: plus: +; CHECK-NEXT: [[N0:%.*]] = call i64 @binop.4(i64 3, i64 1) +; CHECK-NEXT: [[RES0:%.*]] = call i64 @bar.6(i64 4) +; CHECK-NEXT: br label %merge +; CHECK: merge: +; CHECK-NEXT: ret i64 undef +; +; CHECK: define internal i64 @foo.2 +; CHECK-NEXT: entry: +; CHECK-NEXT: br label %minus +; CHECK: minus: +; CHECK-NEXT: [[N1:%.*]] = call i64 @binop.3(i64 4, i64 -1) +; CHECK-NEXT: [[RES1:%.*]] = call i64 @bar.5(i64 3) +; CHECK-NEXT: br label %merge +; CHECK: merge: +; CHECK-NEXT: ret i64 undef +; +entry: + br i1 %flag, label %plus, label %minus + +plus: + %n0 = call i64 @binop(i64 %m, i64 %n) + %res0 = call i64 @bar(i64 %n0) + br label %merge + +minus: + %n1 = call i64 @binop(i64 %m, i64 %n) + %res1 = call i64 @bar(i64 %n1) + br label %merge + +merge: + %res = phi i64 [ %res0, %plus ], [ %res1, %minus] + ret i64 %res +} + +define internal i64 @binop(i64 %x, i64 %y) { +; +; CHECK: define internal i64 @binop.3 +; CHECK-NEXT: entry: +; CHECK-NEXT: ret i64 undef +; +; CHECK: define internal i64 @binop.4 +; CHECK-NEXT: entry: +; CHECK-NEXT: ret i64 undef +; +entry: + %z = add i64 
%x, %y + ret i64 %z +} + +define internal i64 @bar(i64 %n) { +; +; CHECK: define internal i64 @bar.5 +; CHECK-NEXT: entry: +; CHECK-NEXT: br label %if.else +; CHECK: if.else: +; CHECK-NEXT: br label %if.end +; CHECK: if.end: +; CHECK-NEXT: ret i64 undef +; +; CHECK: define internal i64 @bar.6 +; CHECK-NEXT: entry: +; CHECK-NEXT: br label %if.then +; CHECK: if.then: +; CHECK-NEXT: br label %if.end +; CHECK: if.end: +; CHECK-NEXT: ret i64 undef +; +entry: + %cmp = icmp sgt i64 %n, 3 + br i1 %cmp, label %if.then, label %if.else + +if.then: + %res0 = sdiv i64 %n, 2 + br label %if.end + +if.else: + %res1 = mul i64 %n, 2 + br label %if.end + +if.end: + %res = phi i64 [ %res0, %if.then ], [ %res1, %if.else] + ret i64 %res +} +