Index: include/llvm/IR/Instructions.h
===================================================================
--- include/llvm/IR/Instructions.h
+++ include/llvm/IR/Instructions.h
@@ -2413,6 +2413,10 @@
     return getNumOperands() != 0 ? getOperand(0) : nullptr;
   }
 
+  /// Returns the musttail call immediately preceding this return, optionally
+  /// with a single bitcast of its result in between; null if there is none.
+  CallInst *getPrecedingMustTailCallIfPresent();
+
   unsigned getNumSuccessors() const { return 0; }
 
   // Methods for support type inquiry through isa, cast, and dyn_cast:
Index: lib/IR/Instructions.cpp
===================================================================
--- lib/IR/Instructions.cpp
+++ lib/IR/Instructions.cpp
@@ -658,6 +658,32 @@
 ReturnInst::~ReturnInst() {
 }
 
+CallInst *ReturnInst::getPrecedingMustTailCallIfPresent() {
+  // getPrevNode() is null when this is the first instruction in its block.
+  Instruction *Prev = getPrevNode();
+  if (!Prev)
+    return nullptr;
+
+  if (Value *RV = getReturnValue()) {
+    if (RV != Prev)
+      return nullptr;
+
+    // Look through the optional bitcast.
+    if (auto *BI = dyn_cast<BitCastInst>(Prev)) {
+      RV = BI->getOperand(0);
+      Prev = BI->getPrevNode();
+      if (!Prev || RV != Prev)
+        return nullptr;
+    }
+  }
+
+  if (auto *CI = dyn_cast<CallInst>(Prev)) {
+    if (CI->isMustTailCall())
+      return CI;
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 //                        ResumeInst Implementation
 //===----------------------------------------------------------------------===//
Index: lib/Transforms/Instrumentation/MemorySanitizer.cpp
===================================================================
--- lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -2275,8 +2275,9 @@
       // Allow only tail calls with the same types, otherwise
       // we may have a false positive: shadow for a non-void RetVal
       // will get propagated to a void RetVal.
-      if (Call->isTailCall() && Call->getType() != Call->getParent()->getType())
-        Call->setTailCall(false);
+      if (Call->getTailCallKind() == CallInst::TCK_Tail &&
+          Call->getType() != Call->getParent()->getType())
+        Call->setTailCallKind(CallInst::TCK_None);
 
       assert(!isa<IntrinsicInst>(&I) && "intrinsics are handled elsewhere");
 
@@ -2348,6 +2349,19 @@
       VAHelper->visitCallSite(CS, IRB);
     }
 
+    // If this is a musttail call site, we can't insert propagation code here:
+    // the caller and callee prototypes must match, so the shadow the callee
+    // stores for its return value is already set up for an immediate return.
+    // The result still gets a clean shadow and origin, because an optional
+    // bitcast between the call and the ret is instrumented and reads them.
+    if (CS.isMustTailCall()) {
+      if (I.getType()->isSized()) {
+        setShadow(&I, getCleanShadow(&I));
+        setOrigin(&I, getCleanOrigin());
+      }
+      return;
+    }
+
     // Now, get the shadow for the RetVal.
     if (!I.getType()->isSized()) return;
     IRBuilder<> IRBBefore(&I);
@@ -2381,6 +2395,11 @@
   }
 
   void visitReturnInst(ReturnInst &I) {
+    // Nothing may be inserted between a musttail call and the ret, and the
+    // callee has already stored the return value's shadow for our caller.
+    if (I.getPrecedingMustTailCallIfPresent())
+      return;
+
     IRBuilder<> IRB(&I);
     Value *RetVal = I.getReturnValue();
     if (!RetVal) return;
Index: lib/Transforms/Utils/InlineFunction.cpp
===================================================================
--- lib/Transforms/Utils/InlineFunction.cpp
+++ lib/Transforms/Utils/InlineFunction.cpp
@@ -485,33 +485,6 @@
   }
 }
 
-/// Returns a musttail call instruction if one immediately precedes the given
-/// return instruction with an optional bitcast instruction between them.
-static CallInst *getPrecedingMustTailCall(ReturnInst *RI) {
-  Instruction *Prev = RI->getPrevNode();
-  if (!Prev)
-    return nullptr;
-
-  if (Value *RV = RI->getReturnValue()) {
-    if (RV != Prev)
-      return nullptr;
-
-    // Look through the optional bitcast.
-    if (auto *BI = dyn_cast<BitCastInst>(Prev)) {
-      RV = BI->getOperand(0);
-      Prev = BI->getPrevNode();
-      if (!Prev || RV != Prev)
-        return nullptr;
-    }
-  }
-
-  if (auto *CI = dyn_cast<CallInst>(Prev)) {
-    if (CI->isMustTailCall())
-      return CI;
-  }
-  return nullptr;
-}
-
 /// InlineFunction - This function inlines the called function into the basic
 /// block of the caller. This returns false if it is not possible to inline
 /// this call. The program is still in a well defined state if this occurs
@@ -764,7 +737,7 @@
       for (ReturnInst *RI : Returns) {
         // Don't insert llvm.lifetime.end calls between a musttail call and a
         // return. The return kills all local allocas.
-        if (InlinedMustTailCalls && getPrecedingMustTailCall(RI))
+        if (InlinedMustTailCalls && RI->getPrecedingMustTailCallIfPresent())
           continue;
         IRBuilder<>(RI).CreateLifetimeEnd(AI, AllocaSize);
       }
@@ -788,7 +761,7 @@
     for (ReturnInst *RI : Returns) {
       // Don't insert llvm.stackrestore calls between a musttail call and a
       // return. The return will restore the stack pointer.
-      if (InlinedMustTailCalls && getPrecedingMustTailCall(RI))
+      if (InlinedMustTailCalls && RI->getPrecedingMustTailCallIfPresent())
         continue;
       IRBuilder<>(RI).CreateCall(StackRestore, SavedPtr);
     }
@@ -811,7 +784,7 @@
     // Handle the returns preceded by musttail calls separately.
     SmallVector<ReturnInst *, 8> NormalReturns;
     for (ReturnInst *RI : Returns) {
-      CallInst *ReturnedMustTail = getPrecedingMustTailCall(RI);
+      CallInst *ReturnedMustTail = RI->getPrecedingMustTailCallIfPresent();
       if (!ReturnedMustTail) {
         NormalReturns.push_back(RI);
         continue;
Index: test/Instrumentation/MemorySanitizer/msan_basic.ll
===================================================================
--- test/Instrumentation/MemorySanitizer/msan_basic.ll
+++ test/Instrumentation/MemorySanitizer/msan_basic.ll
@@ -825,3 +825,46 @@
 ; CHECK: store i64 16, i64* @__msan_va_arg_overflow_size_tls
 ; CHECK: call void (i32, ...)* @VAArgStructFn
 ; CHECK: ret void
+
+declare i32 @InnerTailCall(i32 %a)
+
+define void @MismatchedTailCallType(i32 %a) {
+  %b = tail call i32 @InnerTailCall(i32 %a)
+  ret void
+}
+
+; Test that 'tail' gets stripped off here so that shadow propagation is correct.
+
+; CHECK-LABEL: define void @MismatchedTailCallType
+; CHECK-NOT: tail
+; CHECK: call i32 @InnerTailCall
+; CHECK: ret void
+
+declare i32 @InnerMustTailCall(i32 %a)
+
+define i32 @MustTailCall(i32 %a) {
+  %b = musttail call i32 @InnerMustTailCall(i32 %a)
+  ret i32 %b
+}
+
+; Test that 'musttail' is preserved. The ABI should make this work.
+
+; CHECK-LABEL: define i32 @MustTailCall
+; CHECK: musttail call i32 @InnerMustTailCall
+; CHECK-NEXT: ret i32
+
+declare i8* @InnerMustTailCallPtr(i8* %a)
+
+define i32* @MustTailCallBitcast(i8* %a) {
+  %b = musttail call i8* @InnerMustTailCallPtr(i8* %a)
+  %c = bitcast i8* %b to i32*
+  ret i32* %c
+}
+
+; Test that a musttail call followed by a bitcast of its result is instrumented
+; without inserting anything between the call, the bitcast and the ret.
+
+; CHECK-LABEL: define i32* @MustTailCallBitcast
+; CHECK: musttail call i8* @InnerMustTailCallPtr
+; CHECK-NEXT: bitcast i8* {{.*}} to i32*
+; CHECK-NEXT: ret i32*