Review notes for this patch (commentary only: this text sits before the first
"diff --git" header and belongs to no hunk).

Summary of what the hunks below do
  1) SCCPSolver.h / SCCPSolver.cpp: SCCPSolver gains solveWhileResolvedUndefs()
     and resetLatticeValueFor(CallBase *); both forward to the method of the
     same name on Visitor.
  2) SCCPSolver.cpp, visitor class: adds the Invalidated set and invalidate(),
     a worklist walk that skips instructions already in Invalidated or whose
     parent block is not in BBExecutable, overwrites the matching entry in
     TrackedRetVals, TrackedMultipleRetVals, StructValueState or ValueState
     with a default-constructed ValueLatticeElement, and then queues the users
     of the value it reset (plus any recorded in AdditionalUsers) that pass
     the dyn_cast. resetLatticeValueFor() calls invalidate(Call) followed by
     handleCallResult(*Call). solveWhileResolvedUndefs() repeats solve() plus
     resolvedUndef() over the entries of Invalidated that pass the dyn_cast
     until no call reports a change, then clears the set. solve() now erases
     every value it pops from either work list from Invalidated.
  3) SCCPSolver.cpp: the per-instruction body of resolvedUndefsIn() moves into
     the new resolvedUndef(Instruction &); resolvedUndefsIn() ORs its result
     into MadeChange.
  4) FunctionSpecialization.cpp: clones are also registered through
     addTrackedFunction(). After the block containing the updateCallSites()
     call, every clone with a non-void return type is examined:
     struct-returning clones are skipped unless
     isStructLatticeConstant(F, STy) holds, other clones are skipped when
     their tracked return value is overdefined. For the remaining clones,
     resetLatticeValueFor() is called on each user that passes the dyn_cast
     and whose called function is the clone; solveWhileResolvedUndefs() runs
     afterwards.
  5) Tests: the phi CHECK line in
     function-specialization-constant-expression.ll now expects ptrtoint
     constant expressions instead of [[TMP0]] / [[TMP1]]; the new
     track-return.ll expects @main to end in "ret i64 8".

Points to confirm before applying
  a) NOTE(review): cast(...), dyn_cast(...), isa(...), SmallVector, DenseSet,
     SmallPtrSet, SmallVectorImpl and "MapVector, ValueLatticeElement>" appear
     without template argument lists; presumably the angle-bracketed parts
     were lost when this patch was captured. Restore them from the original
     review.
  b) NOTE(review): in the struct branch, resolvedUndef() returns true right
     after the first unknown element is marked overdefined, while the removed
     loop in resolvedUndefsIn() marked every unknown element before moving on.
     Confirm the change is intended.
  c) NOTE(review): the loop counters named I inside invalidate() shadow its
     Instruction *I parameter.
  d) NOTE(review): the second hunk of
     function-specialization-constant-expression.ll only appends an empty
     line; presumably unintentional.
  e) NOTE(review): indentation and intra-line spacing below are a best guess;
     check context lines against the target tree.

diff --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
--- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
+++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
@@ -132,6 +132,8 @@
 
   void solveWhileResolvedUndefsIn(SmallVectorImpl &WorkList);
 
+  void solveWhileResolvedUndefs();
+
   bool isBlockExecutable(BasicBlock *BB) const;
 
   // isEdgeFeasible - Return true if the control flow edge from the 'From' basic
@@ -142,6 +144,10 @@
 
   void removeLatticeValueFor(Value *V);
 
+  /// Invalidate the Lattice Value of \p Call and its users after specializing
+  /// the call. Then recompute it.
+  void resetLatticeValueFor(CallBase *Call);
+
   const ValueLatticeElement &getLatticeValueFor(Value *V) const;
 
   /// getTrackedRetVals - Get the inferred return value map.
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -332,6 +332,33 @@
     updateCallSites(F, AllSpecs.begin() + Begin, AllSpecs.begin() + End);
   }
 
+  for (Function *F : Clones) {
+    if (F->getReturnType()->isVoidTy())
+      continue;
+    if (F->getReturnType()->isStructTy()) {
+      StructType *STy = cast(F->getReturnType());
+      if (!Solver.isStructLatticeConstant(F, STy))
+        continue;
+    } else {
+      auto It = Solver.getTrackedRetVals().find(F);
+      assert(It != Solver.getTrackedRetVals().end() &&
+             "Return value ought to be tracked");
+      if (SCCPSolver::isOverdefined(It->second))
+        continue;
+    }
+    for (User *U : F->users()) {
+      if (auto *CS = dyn_cast(U)) {
+        // The user instruction does not call our function.
+        if (CS->getCalledFunction() != F)
+          continue;
+        Solver.resetLatticeValueFor(CS);
+      }
+    }
+  }
+
+  // Rerun the solver to notify the users of the modified callsites.
+  Solver.solveWhileResolvedUndefs();
+
   promoteConstantStackValues();
   return true;
 }
@@ -514,6 +541,7 @@
 
   Function *Clone = cloneCandidateFunction(F);
   Solver.addArgumentTrackedFunction(Clone);
+  Solver.addTrackedFunction(Clone);
 
   // Mark all the specialized functions
   Specializations.insert(Clone);
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
--- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp
+++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
@@ -352,6 +352,10 @@
   MapVector, ValueLatticeElement>
       TrackedMultipleRetVals;
 
+  /// The set of values whose lattice has been invalidated.
+  /// Populated by resetLatticeValueFor(), cleared after resolving undefs.
+  DenseSet Invalidated;
+
   /// MRVFunctionsTracked - Each function in TrackedMultipleRetVals is
   /// represented here for efficient lookup.
   SmallPtrSet MRVFunctionsTracked;
@@ -477,6 +481,70 @@
     return LV;
   }
 
+  /// Walk \p I and, transitively, its users, resetting their tracked lattice
+  /// values to "unknown" on the way.
+  void invalidate(Instruction *I) {
+    SmallVector ToInvalidate;
+    ToInvalidate.push_back(I);
+
+    while (!ToInvalidate.empty()) {
+      Instruction *Inst = ToInvalidate.pop_back_val();
+
+      auto [_, Inserted] = Invalidated.insert(Inst);
+      if (!Inserted)
+        continue;
+
+      if (!BBExecutable.count(Inst->getParent()))
+        continue;
+
+      Value *V = nullptr;
+      // For return instructions we need to invalidate the tracked returns map.
+      // Anything else has its lattice in the value map.
+      if (auto *RetInst = dyn_cast(Inst)) {
+        Function *F = RetInst->getParent()->getParent();
+        auto TFRVI = TrackedRetVals.find(F);
+        if (TFRVI != TrackedRetVals.end()) {
+          TFRVI->second = ValueLatticeElement();
+          V = F;
+        } else if (MRVFunctionsTracked.count(F)) {
+          auto *STy = cast(F->getReturnType());
+          for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I)
+            TrackedMultipleRetVals[std::make_pair(F, I)] =
+                ValueLatticeElement();
+          V = F;
+        }
+      } else if (auto *STy = dyn_cast(Inst->getType())) {
+        for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) {
+          auto It = StructValueState.find(std::make_pair(Inst, I));
+          if (It != StructValueState.end()) {
+            It->second = ValueLatticeElement();
+            V = Inst;
+          }
+        }
+      } else {
+        auto It = ValueState.find(Inst);
+        if (It != ValueState.end()) {
+          It->second = ValueLatticeElement();
+          V = Inst;
+        }
+      }
+
+      if (V) {
+        LLVM_DEBUG(dbgs() << "Invalidated lattice for " << *V << "\n");
+
+        for (User *U : V->users())
+          if (auto *UI = dyn_cast(U))
+            ToInvalidate.push_back(UI);
+
+        auto It = AdditionalUsers.find(V);
+        if (It != AdditionalUsers.end())
+          for (User *U : It->second)
+            if (auto *UI = dyn_cast(U))
+              ToInvalidate.push_back(UI);
+      }
+    }
+  }
+
   /// markEdgeExecutable - Mark a basic block as executable, adding it to the BB
   /// work list if it is not already executable.
   bool markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest);
@@ -656,6 +724,8 @@
 
   void solve();
 
+  bool resolvedUndef(Instruction &I);
+
   bool resolvedUndefsIn(Function &F);
 
   bool isBlockExecutable(BasicBlock *BB) const {
@@ -678,6 +748,19 @@
 
   void removeLatticeValueFor(Value *V) { ValueState.erase(V); }
 
+  /// Invalidate the Lattice Value of \p Call and its users after specializing
+  /// the call. Then recompute it.
+  void resetLatticeValueFor(CallBase *Call) {
+    // Calls to void returning functions do not need invalidation.
+    Function *F = Call->getCalledFunction();
+    (void)F;
+    assert(!F->getReturnType()->isVoidTy() &&
+           (TrackedRetVals.count(F) || MRVFunctionsTracked.count(F)) &&
+           "All non void specializations should be tracked");
+    invalidate(Call);
+    handleCallResult(*Call);
+  }
+
   const ValueLatticeElement &getLatticeValueFor(Value *V) const {
     assert(!V->getType()->isStructTy() &&
            "Should use getStructLatticeValueFor");
@@ -744,6 +827,18 @@
         ResolvedUndefs |= resolvedUndefsIn(*F);
     }
   }
+
+  void solveWhileResolvedUndefs() {
+    bool ResolvedUndefs = true;
+    while (ResolvedUndefs) {
+      solve();
+      ResolvedUndefs = false;
+      for (Value *V : Invalidated)
+        if (auto *I = dyn_cast(V))
+          ResolvedUndefs |= resolvedUndef(*I);
+    }
+    Invalidated.clear();
+  }
 };
 
 } // namespace llvm
@@ -1678,6 +1773,7 @@
     // things to overdefined more quickly.
     while (!OverdefinedInstWorkList.empty()) {
       Value *I = OverdefinedInstWorkList.pop_back_val();
+      Invalidated.erase(I);
 
       LLVM_DEBUG(dbgs() << "\nPopped off OI-WL: " << *I << '\n');
 
@@ -1694,6 +1790,7 @@
     // Process the instruction work list.
     while (!InstWorkList.empty()) {
       Value *I = InstWorkList.pop_back_val();
+      Invalidated.erase(I);
 
       LLVM_DEBUG(dbgs() << "\nPopped off I-WL: " << *I << '\n');
 
@@ -1721,6 +1818,61 @@
   }
 }
 
+bool SCCPInstVisitor::resolvedUndef(Instruction &I) {
+  // Look for instructions which produce undef values.
+  if (I.getType()->isVoidTy())
+    return false;
+
+  if (auto *STy = dyn_cast(I.getType())) {
+    // Only a few things that can be structs matter for undef.
+
+    // Tracked calls must never be marked overdefined in resolvedUndefsIn.
+    if (auto *CB = dyn_cast(&I))
+      if (Function *F = CB->getCalledFunction())
+        if (MRVFunctionsTracked.count(F))
+          return false;
+
+    // extractvalue and insertvalue don't need to be marked; they are
+    // tracked as precisely as their operands.
+    if (isa(I) || isa(I))
+      return false;
+    // Send the results of everything else to overdefined. We could be
+    // more precise than this but it isn't worth bothering.
+    for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
+      ValueLatticeElement &LV = getStructValueState(&I, i);
+      if (LV.isUnknown()) {
+        markOverdefined(LV, &I);
+        return true;
+      }
+    }
+    return false;
+  }
+
+  ValueLatticeElement &LV = getValueState(&I);
+  if (!LV.isUnknown())
+    return false;
+
+  // There are two reasons a call can have an undef result
+  // 1. It could be tracked.
+  // 2. It could be constant-foldable.
+  // Because of the way we solve return values, tracked calls must
+  // never be marked overdefined in resolvedUndefsIn.
+  if (auto *CB = dyn_cast(&I))
+    if (Function *F = CB->getCalledFunction())
+      if (TrackedRetVals.count(F))
+        return false;
+
+  if (isa(I)) {
+    // A load here means one of two things: a load of undef from a global,
+    // a load from an unknown pointer. Either way, having it return undef
+    // is okay.
+    return false;
+  }
+
+  markOverdefined(&I);
+  return true;
+}
+
 /// While solving the dataflow for a function, we don't compute a result for
 /// operations with an undef operand, to allow undef to be lowered to a
 /// constant later. For example, constant folding of "zext i8 undef to i16"
@@ -1740,60 +1892,8 @@
     if (!BBExecutable.count(&BB))
       continue;
 
-    for (Instruction &I : BB) {
-      // Look for instructions which produce undef values.
-      if (I.getType()->isVoidTy())
-        continue;
-
-      if (auto *STy = dyn_cast(I.getType())) {
-        // Only a few things that can be structs matter for undef.
-
-        // Tracked calls must never be marked overdefined in resolvedUndefsIn.
-        if (auto *CB = dyn_cast(&I))
-          if (Function *F = CB->getCalledFunction())
-            if (MRVFunctionsTracked.count(F))
-              continue;
-
-        // extractvalue and insertvalue don't need to be marked; they are
-        // tracked as precisely as their operands.
-        if (isa(I) || isa(I))
-          continue;
-        // Send the results of everything else to overdefined. We could be
-        // more precise than this but it isn't worth bothering.
-        for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
-          ValueLatticeElement &LV = getStructValueState(&I, i);
-          if (LV.isUnknown()) {
-            markOverdefined(LV, &I);
-            MadeChange = true;
-          }
-        }
-        continue;
-      }
-
-      ValueLatticeElement &LV = getValueState(&I);
-      if (!LV.isUnknown())
-        continue;
-
-      // There are two reasons a call can have an undef result
-      // 1. It could be tracked.
-      // 2. It could be constant-foldable.
-      // Because of the way we solve return values, tracked calls must
-      // never be marked overdefined in resolvedUndefsIn.
-      if (auto *CB = dyn_cast(&I))
-        if (Function *F = CB->getCalledFunction())
-          if (TrackedRetVals.count(F))
-            continue;
-
-      if (isa(I)) {
-        // A load here means one of two things: a load of undef from a global,
-        // a load from an unknown pointer. Either way, having it return undef
-        // is okay.
-        continue;
-      }
-
-      markOverdefined(&I);
-      MadeChange = true;
-    }
+    for (Instruction &I : BB)
+      MadeChange |= resolvedUndef(I);
   }
 
   LLVM_DEBUG(if (MadeChange) dbgs()
@@ -1871,6 +1971,10 @@
   Visitor->solveWhileResolvedUndefsIn(WorkList);
 }
 
+void SCCPSolver::solveWhileResolvedUndefs() {
+  Visitor->solveWhileResolvedUndefs();
+}
+
 bool SCCPSolver::isBlockExecutable(BasicBlock *BB) const {
   return Visitor->isBlockExecutable(BB);
 }
@@ -1888,6 +1992,10 @@
   return Visitor->removeLatticeValueFor(V);
 }
 
+void SCCPSolver::resetLatticeValueFor(CallBase *Call) {
+  Visitor->resetLatticeValueFor(Call);
+}
+
 const ValueLatticeElement &SCCPSolver::getLatticeValueFor(Value *V) const {
   return Visitor->getLatticeValueFor(V);
 }
diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll
--- a/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll
+++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll
@@ -36,7 +36,7 @@
 ; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @func2.1(ptr getelementptr inbounds ([[STRUCT]], ptr @Global, i32 0, i32 4))
 ; CHECK-NEXT:    br label [[MERGE]]
 ; CHECK:       merge:
-; CHECK-NEXT:    [[TMP2:%.*]] = phi i64 [ [[TMP0]], [[PLUS]] ], [ [[TMP1]], [[MINUS]] ]
+; CHECK-NEXT:    [[TMP2:%.*]] = phi i64 [ ptrtoint (ptr getelementptr inbounds ([[STRUCT:%.*]], ptr @Global, i32 0, i32 3) to i64), [[PLUS]] ], [ ptrtoint (ptr getelementptr inbounds ([[STRUCT:%.*]], ptr @Global, i32 0, i32 4) to i64), [[MINUS]] ]
 ; CHECK-NEXT:    ret i64 [[TMP2]]
 ;
 entry:
@@ -70,3 +70,4 @@
   %3 = add i64 %1, %2
   ret i64 %3
 }
+
diff --git a/llvm/test/Transforms/FunctionSpecialization/track-return.ll b/llvm/test/Transforms/FunctionSpecialization/track-return.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/FunctionSpecialization/track-return.ll
@@ -0,0 +1,106 @@
+; RUN: opt -passes="ipsccp" -force-specialization -funcspec-for-literal-constant -funcspec-max-iters=3 -S < %s | FileCheck %s
+
+define i64 @main() {
+; CHECK: define i64 @main
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[C1:%.*]] = call i64 @foo.1(i1 true, i64 3, i64 1)
+; CHECK-NEXT: [[C2:%.*]] = call i64 @foo.2(i1 false, i64 4, i64 -1)
+; CHECK-NEXT: ret i64 8
+;
+entry:
+  %c1 = call i64 @foo(i1 true, i64 3, i64 1)
+  %c2 = call i64 @foo(i1 false, i64 4, i64 -1)
+  %add = add i64 %c1, %c2
+  ret i64 %add
+}
+
+define internal i64 @foo(i1 %flag, i64 %m, i64 %n) {
+;
+; CHECK: define internal i64 @foo.1
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %plus
+; CHECK: plus:
+; CHECK-NEXT: [[N0:%.*]] = call i64 @binop.4(i64 3, i64 1)
+; CHECK-NEXT: [[RES0:%.*]] = call i64 @bar.6(i64 4)
+; CHECK-NEXT: br label %merge
+; CHECK: merge:
+; CHECK-NEXT: ret i64 undef
+;
+; CHECK: define internal i64 @foo.2
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %minus
+; CHECK: minus:
+; CHECK-NEXT: [[N1:%.*]] = call i64 @binop.3(i64 4, i64 -1)
+; CHECK-NEXT: [[RES1:%.*]] = call i64 @bar.5(i64 3)
+; CHECK-NEXT: br label %merge
+; CHECK: merge:
+; CHECK-NEXT: ret i64 undef
+;
+entry:
+  br i1 %flag, label %plus, label %minus
+
+plus:
+  %n0 = call i64 @binop(i64 %m, i64 %n)
+  %res0 = call i64 @bar(i64 %n0)
+  br label %merge
+
+minus:
+  %n1 = call i64 @binop(i64 %m, i64 %n)
+  %res1 = call i64 @bar(i64 %n1)
+  br label %merge
+
+merge:
+  %res = phi i64 [ %res0, %plus ], [ %res1, %minus]
+  ret i64 %res
+}
+
+define internal i64 @binop(i64 %x, i64 %y) {
+;
+; CHECK: define internal i64 @binop.3
+; CHECK-NEXT: entry:
+; CHECK-NEXT: ret i64 undef
+;
+; CHECK: define internal i64 @binop.4
+; CHECK-NEXT: entry:
+; CHECK-NEXT: ret i64 undef
+;
+entry:
+  %z = add i64 %x, %y
+  ret i64 %z
+}
+
+define internal i64 @bar(i64 %n) {
+;
+; CHECK: define internal i64 @bar.5
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %if.else
+; CHECK: if.else:
+; CHECK-NEXT: br label %if.end
+; CHECK: if.end:
+; CHECK-NEXT: ret i64 undef
+;
+; CHECK: define internal i64 @bar.6
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %if.then
+; CHECK: if.then:
+; CHECK-NEXT: br label %if.end
+; CHECK: if.end:
+; CHECK-NEXT: ret i64 undef
+;
+entry:
+  %cmp = icmp sgt i64 %n, 3
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:
+  %res0 = sdiv i64 %n, 2
+  br label %if.end
+
+if.else:
+  %res1 = mul i64 %n, 2
+  br label %if.end
+
+if.end:
+  %res = phi i64 [ %res0, %if.then ], [ %res1, %if.else]
+  ret i64 %res
+}
+