Index: lib/Transforms/Scalar/SCCP.cpp =================================================================== --- lib/Transforms/Scalar/SCCP.cpp +++ lib/Transforms/Scalar/SCCP.cpp @@ -223,6 +223,10 @@ /// represented here for efficient lookup. SmallPtrSet MRVFunctionsTracked; + /// MustTailCallees - Each function here is a callee of a non-removable + /// musttail call site. + SmallPtrSet MustTailCallees; + /// TrackingIncomingArguments - This is the set of functions for whose /// arguments we make optimistic assumptions about and try to prove as /// constants. @@ -289,6 +293,18 @@ TrackedRetVals.insert(std::make_pair(F, LatticeVal())); } + /// AddMustTailCallee - Record that the SCCP solver found this function to be + /// called from a non-removable musttail call site. + void AddMustTailCallee(Function *F) { + MustTailCallees.insert(F); + } + + /// Returns true if the given function is called from a non-removable musttail + /// call site. + bool isMustTailCallee(Function *F) { + return MustTailCallees.count(F); + } + void AddArgumentTrackedFunction(Function *F) { TrackingIncomingArguments.insert(F); } @@ -358,6 +374,12 @@ return MRVFunctionsTracked; } + /// getMustTailCallees - Get the set of functions which are called + /// from non-removable musttail call sites. + const SmallPtrSet getMustTailCallees() { + return MustTailCallees; + } + /// markOverdefined - Mark the specified value overdefined. This /// works with both scalars and structs. void markOverdefined(Value *V) { @@ -1672,6 +1694,23 @@ IV.isConstant() ? 
IV.getConstant() : UndefValue::get(V->getType()); } assert(Const && "Constant is nullptr here!"); + + // Replacing a `musttail` call with a constant breaks the `musttail` + // invariant, unless the call itself can be removed. + CallInst *CI = dyn_cast(V); + if (CI && CI->isMustTailCall() && !CI->isSafeToRemove()) { + CallSite CS(CI); + Function *F = CS.getCalledFunction(); + + // Don't zap returns of the callee + if (F) + Solver.AddMustTailCallee(F); + + DEBUG(dbgs() << " Can\'t treat the result of musttail call : " << *CI + << " as a constant\n"); + return false; + } + DEBUG(dbgs() << " Constant: " << *Const << " = " << *V << '\n'); // Replaces all of the uses of a variable with uses of the constant. @@ -1802,10 +1841,25 @@ if (!Solver.isArgumentTrackedFunction(&F)) return; - for (BasicBlock &BB : F) + // There is a non-removable musttail call site of this function. Zapping + // returns is not allowed. + if (Solver.isMustTailCallee(&F)) { + DEBUG(dbgs() << "Can't zap returns of the function : " << F.getName() + << " due to present musttail call of it\n"); + return; + } + + for (BasicBlock &BB : F) { + if (CallInst *CI = BB.getTerminatingMustTailCall()) { + DEBUG(dbgs() << "Can't zap return of the block due to present " + << "musttail call : " << *CI << "\n"); + return; + } + if (auto *RI = dyn_cast(BB.getTerminator())) if (!isa(RI->getOperand(0))) ReturnsToZap.push_back(RI); + } } static bool runIPSCCP(Module &M, const DataLayout &DL, Index: test/Transforms/IPConstantProp/musttail-call.ll =================================================================== --- /dev/null +++ test/Transforms/IPConstantProp/musttail-call.ll @@ -0,0 +1,58 @@ +; RUN: opt < %s -ipsccp -S | FileCheck %s +; PR36485 +; musttail call result can't be replaced with a constant, unless the call +; can be removed + +declare i32 @external() + +define i8* @start(i8 %v) { + %c1 = icmp eq i8 %v, 0 + br i1 %c1, label %true, label %false +true: + ; CHECK: %ca = musttail call i8* @side_effects(i8 %v) + ; 
CHECK: ret i8* %ca + %ca = musttail call i8* @side_effects(i8 %v) + ret i8* %ca +false: + %c2 = icmp eq i8 %v, 1 + br i1 %c2, label %c2_true, label %c2_false +c2_true: + %ca1 = musttail call i8* @no_side_effects(i8 %v) + ; CHECK: ret i8* null + ret i8* %ca1 +c2_false: + ; CHECK: %ca2 = musttail call i8* @dont_zap_me(i8 %v) + ; CHECK: ret i8* %ca2 + %ca2 = musttail call i8* @dont_zap_me(i8 %v) + ret i8* %ca2 +} + +define internal i8* @side_effects(i8 %v) { + %i1 = call i32 @external() + + ; Since this goes back to `start`, SCCP should see that the return value + ; is always `null`. + ; The call can't be removed due to the `external` call above, though. + + ; CHECK: %ca = musttail call i8* @start(i8 %v) + %ca = musttail call i8* @start(i8 %v) + + ; Thus the result must be returned anyway + ; CHECK: ret i8* %ca + ret i8* %ca +} + +define internal i8* @no_side_effects(i8 %v) readonly nounwind { + ; The call to this function is removed, so the return value must be zapped + ; CHECK: ret i8* undef + ret i8* null +} + +define internal i8* @dont_zap_me(i8 %v) { + %i1 = call i32 @external() + + ; The call to this function cannot be removed due to side effects. Thus the + ; return value should stay as it is, and should not be zapped. + ; CHECK: ret i8* null + ret i8* null +}