Index: include/llvm/IR/Instructions.h
===================================================================
--- include/llvm/IR/Instructions.h
+++ include/llvm/IR/Instructions.h
@@ -2413,6 +2413,10 @@
     return getNumOperands() != 0 ? getOperand(0) : nullptr;
   }
 
+  /// Returns the musttail call that immediately precedes this return, allowing
+  /// one bitcast of the call's result in between; returns null otherwise.
+  CallInst *getPrecedingMustTailCallIfPresent();
+
   unsigned getNumSuccessors() const { return 0; }
 
   // Methods for support type inquiry through isa, cast, and dyn_cast:
Index: lib/IR/Instructions.cpp
===================================================================
--- lib/IR/Instructions.cpp
+++ lib/IR/Instructions.cpp
@@ -658,6 +658,33 @@
 ReturnInst::~ReturnInst() {
 }
 
+CallInst *ReturnInst::getPrecedingMustTailCallIfPresent() {
+  // getPrevNode() is null when this is the first instruction in the block.
+  Instruction *Prev = getPrevNode();
+  if (!Prev)
+    return nullptr;
+
+  if (Value *RV = getReturnValue()) {
+    // The returned value must come from the instruction right before the ret.
+    if (RV != Prev)
+      return nullptr;
+
+    // Look through the optional bitcast.
+    if (auto *BI = dyn_cast<BitCastInst>(Prev)) {
+      RV = BI->getOperand(0);
+      Prev = BI->getPrevNode();
+      if (!Prev || RV != Prev)
+        return nullptr;
+    }
+  }
+
+  if (auto *CI = dyn_cast<CallInst>(Prev)) {
+    if (CI->isMustTailCall())
+      return CI;
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // ResumeInst Implementation
 //===----------------------------------------------------------------------===//
Index: lib/Transforms/Instrumentation/MemorySanitizer.cpp
===================================================================
--- lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -2342,6 +2342,15 @@
       VAHelper->visitCallSite(CS, IRB);
     }
 
+    // If this is a musttail call site, we can't insert propagation code here.
+    // The return type of the caller must match the callee, so the shadow should
+    // already be set up for an immediate return.
+    // NOTE(review): a bitcast of the result, which may sit between this call
+    // and the ret, is still instrumented separately and presumably expects a
+    // shadow for this call; confirm that case is handled.
+    if (CS.isMustTailCall())
+      return;
+
     // Now, get the shadow for the RetVal.
     if (!I.getType()->isSized()) return;
     IRBuilder<> IRBBefore(&I);
@@ -2375,6 +2384,10 @@
   }
 
   void visitReturnInst(ReturnInst &I) {
+    // Don't propagate shadow between musttail calls and the return.
+    if (I.getPrecedingMustTailCallIfPresent())
+      return;
+
     IRBuilder<> IRB(&I);
     Value *RetVal = I.getReturnValue();
     if (!RetVal) return;
Index: lib/Transforms/Utils/InlineFunction.cpp
===================================================================
--- lib/Transforms/Utils/InlineFunction.cpp
+++ lib/Transforms/Utils/InlineFunction.cpp
@@ -486,33 +486,6 @@
   }
 }
 
-/// Returns a musttail call instruction if one immediately precedes the given
-/// return instruction with an optional bitcast instruction between them.
-static CallInst *getPrecedingMustTailCall(ReturnInst *RI) {
-  Instruction *Prev = RI->getPrevNode();
-  if (!Prev)
-    return nullptr;
-
-  if (Value *RV = RI->getReturnValue()) {
-    if (RV != Prev)
-      return nullptr;
-
-    // Look through the optional bitcast.
-    if (auto *BI = dyn_cast<BitCastInst>(Prev)) {
-      RV = BI->getOperand(0);
-      Prev = BI->getPrevNode();
-      if (!Prev || RV != Prev)
-        return nullptr;
-    }
-  }
-
-  if (auto *CI = dyn_cast<CallInst>(Prev)) {
-    if (CI->isMustTailCall())
-      return CI;
-  }
-  return nullptr;
-}
-
 /// InlineFunction - This function inlines the called function into the basic
 /// block of the caller. This returns false if it is not possible to inline
 /// this call. The program is still in a well defined state if this occurs
@@ -765,7 +738,7 @@
       for (ReturnInst *RI : Returns) {
         // Don't insert llvm.lifetime.end calls between a musttail call and a
         // return. The return kills all local allocas.
-        if (InlinedMustTailCalls && getPrecedingMustTailCall(RI))
+        if (InlinedMustTailCalls && RI->getPrecedingMustTailCallIfPresent())
           continue;
         IRBuilder<>(RI).CreateLifetimeEnd(AI, AllocaSize);
       }
@@ -789,7 +762,7 @@
     for (ReturnInst *RI : Returns) {
       // Don't insert llvm.stackrestore calls between a musttail call and a
       // return. The return will restore the stack pointer.
-      if (InlinedMustTailCalls && getPrecedingMustTailCall(RI))
+      if (InlinedMustTailCalls && RI->getPrecedingMustTailCallIfPresent())
         continue;
       IRBuilder<>(RI).CreateCall(StackRestore, SavedPtr);
     }
@@ -812,7 +785,7 @@
     // Handle the returns preceded by musttail calls separately.
     SmallVector<ReturnInst *, 8> NormalReturns;
     for (ReturnInst *RI : Returns) {
-      CallInst *ReturnedMustTail = getPrecedingMustTailCall(RI);
+      CallInst *ReturnedMustTail = RI->getPrecedingMustTailCallIfPresent();
       if (!ReturnedMustTail) {
         NormalReturns.push_back(RI);
         continue;
Index: test/Instrumentation/MemorySanitizer/msan_basic.ll
===================================================================
--- test/Instrumentation/MemorySanitizer/msan_basic.ll
+++ test/Instrumentation/MemorySanitizer/msan_basic.ll
@@ -860,3 +860,16 @@
 ; CHECK-LABEL: define void @MismatchedReturnTypeTailCall
 ; CHECK: tail call i32 @InnerTailCall
 ; CHECK: ret void
+
+declare i32 @InnerMustTailCall(i32 %a)
+
+define i32 @MustTailCall(i32 %a) {
+  %b = musttail call i32 @InnerMustTailCall(i32 %a)
+  ret i32 %b
+}
+
+; Test that 'musttail' is preserved. The ABI should make this work.
+
+; CHECK-LABEL: define i32 @MustTailCall
+; CHECK: musttail call i32 @InnerMustTailCall
+; CHECK-NEXT: ret i32