diff --git a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
--- a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
@@ -561,22 +561,6 @@
   }
 }
 
-// For a ret instruction followed by a musttail call, we cannot insert anything
-// in between. Instead we use the musttail call instruction as the insertion
-// point.
-static Instruction *adjustForMusttailCall(Instruction *I) {
-  ReturnInst *RI = dyn_cast<ReturnInst>(I);
-  if (!RI)
-    return I;
-  Instruction *Prev = RI->getPrevNode();
-  if (BitCastInst *BCI = dyn_cast_or_null<BitCastInst>(Prev))
-    Prev = BCI->getPrevNode();
-  if (CallInst *CI = dyn_cast_or_null<CallInst>(Prev))
-    if (CI->isMustTailCall())
-      return CI;
-  return RI;
-}
-
 namespace {
 
 /// Module analysis for getting various metadata about the module.
@@ -985,8 +969,14 @@
   void createDynamicAllocasInitStorage();
 
   // ----------------------- Visitors.
-  /// Collect all Ret instructions.
-  void visitReturnInst(ReturnInst &RI) { RetVec.push_back(&RI); }
+  /// Collect all Ret instructions, or the musttail call instruction if it
+  /// precedes the return instruction.
+  void visitReturnInst(ReturnInst &RI) {
+    if (CallInst *CI = RI.getParent()->getTerminatingMustTailCall())
+      RetVec.push_back(CI);
+    else
+      RetVec.push_back(&RI);
+  }
 
   /// Collect all Resume instructions.
   void visitResumeInst(ResumeInst &RI) { RetVec.push_back(&RI); }
@@ -1021,8 +1011,7 @@
   // Unpoison dynamic allocas redzones.
   void unpoisonDynamicAllocas() {
     for (Instruction *Ret : RetVec)
-      unpoisonDynamicAllocasBeforeInst(adjustForMusttailCall(Ret),
-                                       DynamicAllocaLayout);
+      unpoisonDynamicAllocasBeforeInst(Ret, DynamicAllocaLayout);
 
     for (Instruction *StackRestoreInst : StackRestoreVec)
       unpoisonDynamicAllocasBeforeInst(StackRestoreInst,
@@ -3333,8 +3322,7 @@
 
   // (Un)poison the stack before all ret instructions.
   for (Instruction *Ret : RetVec) {
-    Instruction *Adjusted = adjustForMusttailCall(Ret);
-    IRBuilder<> IRBRet(Adjusted);
+    IRBuilder<> IRBRet(Ret);
     // Mark the current frame as retired.
     IRBRet.CreateStore(ConstantInt::get(IntptrTy, kRetiredStackFrameMagic),
                        BasePlus0);
@@ -3353,7 +3341,7 @@
       Value *Cmp =
           IRBRet.CreateICmpNE(FakeStack, Constant::getNullValue(IntptrTy));
       Instruction *ThenTerm, *ElseTerm;
-      SplitBlockAndInsertIfThenElse(Cmp, Adjusted, &ThenTerm, &ElseTerm);
+      SplitBlockAndInsertIfThenElse(Cmp, Ret, &ThenTerm, &ElseTerm);
 
       IRBuilder<> IRBPoison(ThenTerm);
       if (StackMallocIdx <= 4) {
diff --git a/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp b/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp
--- a/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp
+++ b/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp
@@ -97,13 +97,8 @@
         continue;
 
       // If T is preceded by a musttail call, that's the real terminator.
-      Instruction *Prev = T->getPrevNode();
-      if (BitCastInst *BCI = dyn_cast_or_null<BitCastInst>(Prev))
-        Prev = BCI->getPrevNode();
-      if (CallInst *CI = dyn_cast_or_null<CallInst>(Prev)) {
-        if (CI->isMustTailCall())
-          T = CI;
-      }
+      if (CallInst *CI = BB.getTerminatingMustTailCall())
+        T = CI;
 
       DebugLoc DL;
       if (DebugLoc TerminatorDL = T->getDebugLoc())
diff --git a/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp b/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp
--- a/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp
+++ b/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp
@@ -41,27 +41,11 @@
     if (!isa<ReturnInst>(TI) && !isa<ResumeInst>(TI))
       continue;
 
-    // If the ret instruction is followed by a musttaill call,
-    // or a bitcast instruction and then a musttail call, we should return
-    // the musttail call as the insertion point to not break the musttail
-    // contract.
-    auto AdjustMustTailCall = [&](Instruction *I) -> Instruction * {
-      auto *RI = dyn_cast<ReturnInst>(I);
-      if (!RI || !RI->getPrevNode())
-        return I;
-      auto *CI = dyn_cast<CallInst>(RI->getPrevNode());
-      if (CI && CI->isMustTailCall())
-        return CI;
-      auto *BI = dyn_cast<BitCastInst>(RI->getPrevNode());
-      if (!BI || !BI->getPrevNode())
-        return I;
-      CI = dyn_cast<CallInst>(BI->getPrevNode());
-      if (CI && CI->isMustTailCall())
-        return CI;
-      return I;
-    };
-
-    Builder.SetInsertPoint(AdjustMustTailCall(TI));
+    // A ret preceded by a musttail call (optionally via a bitcast) must stay
+    // adjacent to it, so insert before the call to keep the musttail contract.
+    if (CallInst *CI = CurBB->getTerminatingMustTailCall())
+      TI = CI;
+    Builder.SetInsertPoint(TI);
     return &Builder;
   }
 