diff --git a/llvm/lib/CodeGen/SafeStack.cpp b/llvm/lib/CodeGen/SafeStack.cpp --- a/llvm/lib/CodeGen/SafeStack.cpp +++ b/llvm/lib/CodeGen/SafeStack.cpp @@ -151,7 +151,7 @@ Value *getStackGuard(IRBuilder<> &IRB, Function &F); /// Load stack guard from the frame and check if it has changed. - void checkStackGuard(IRBuilder<> &IRB, Function &F, ReturnInst &RI, + void checkStackGuard(IRBuilder<> &IRB, Function &F, Instruction &RI, AllocaInst *StackGuardSlot, Value *StackGuard); /// Find all static allocas, dynamic allocas, return instructions and @@ -160,7 +160,7 @@ void findInsts(Function &F, SmallVectorImpl &StaticAllocas, SmallVectorImpl &DynamicAllocas, SmallVectorImpl &ByValArguments, - SmallVectorImpl &Returns, + SmallVectorImpl &Returns, SmallVectorImpl &StackRestorePoints); /// Calculate the allocation size of a given alloca. Returns 0 if the @@ -168,15 +168,13 @@ uint64_t getStaticAllocaAllocationSize(const AllocaInst* AI); /// Allocate space for all static allocas in \p StaticAllocas, - /// replace allocas with pointers into the unsafe stack and generate code to - /// restore the stack pointer before all return instructions in \p Returns. + /// replace allocas with pointers into the unsafe stack. /// /// \returns A pointer to the top of the unsafe stack after all unsafe static /// allocas are allocated. Value *moveStaticAllocasToUnsafeStack(IRBuilder<> &IRB, Function &F, ArrayRef StaticAllocas, ArrayRef ByValArguments, - ArrayRef Returns, Instruction *BasePointer, AllocaInst *StackGuardSlot); @@ -383,7 +381,7 @@ SmallVectorImpl &StaticAllocas, SmallVectorImpl &DynamicAllocas, SmallVectorImpl &ByValArguments, - SmallVectorImpl &Returns, + SmallVectorImpl &Returns, SmallVectorImpl &StackRestorePoints) { for (Instruction &I : instructions(&F)) { if (auto AI = dyn_cast(&I)) { @@ -401,7 +399,10 @@ DynamicAllocas.push_back(AI); } } else if (auto RI = dyn_cast(&I)) { - Returns.push_back(RI); + if (CallInst *CI = I.getParent()->getTerminatingMustTailCall()) + Returns.push_back(CI); + else + Returns.push_back(RI); } else if (auto CI = dyn_cast(&I)) { // setjmps require stack restore. if (CI->getCalledFunction() && CI->canReturnTwice()) @@ -465,7 +466,7 @@ return DynamicTop; } -void SafeStack::checkStackGuard(IRBuilder<> &IRB, Function &F, ReturnInst &RI, +void SafeStack::checkStackGuard(IRBuilder<> &IRB, Function &F, Instruction &RI, AllocaInst *StackGuardSlot, Value *StackGuard) { Value *V = IRB.CreateLoad(StackPtrTy, StackGuardSlot); Value *Cmp = IRB.CreateICmpNE(StackGuard, V); @@ -490,8 +491,8 @@ /// prologue into a local variable and restore it in the epilogue. Value *SafeStack::moveStaticAllocasToUnsafeStack( IRBuilder<> &IRB, Function &F, ArrayRef StaticAllocas, - ArrayRef ByValArguments, ArrayRef Returns, - Instruction *BasePointer, AllocaInst *StackGuardSlot) { + ArrayRef ByValArguments, Instruction *BasePointer, + AllocaInst *StackGuardSlot) { if (StaticAllocas.empty() && ByValArguments.empty()) return BasePointer; @@ -759,7 +760,7 @@ SmallVector StaticAllocas; SmallVector DynamicAllocas; SmallVector ByValArguments; - SmallVector Returns; + SmallVector Returns; // Collect all points where stack gets unwound and needs to be restored // This is only necessary because the runtime (setjmp and unwind code) is @@ -812,7 +813,7 @@ StackGuardSlot = IRB.CreateAlloca(StackPtrTy, nullptr); IRB.CreateStore(StackGuard, StackGuardSlot); - for (ReturnInst *RI : Returns) { + for (Instruction *RI : Returns) { IRBuilder<> IRBRet(RI); checkStackGuard(IRBRet, F, *RI, StackGuardSlot, StackGuard); } @@ -820,9 +821,8 @@ // The top of the unsafe stack after all unsafe static allocas are // allocated. - Value *StaticTop = - moveStaticAllocasToUnsafeStack(IRB, F, StaticAllocas, ByValArguments, - Returns, BasePointer, StackGuardSlot); + Value *StaticTop = moveStaticAllocasToUnsafeStack( + IRB, F, StaticAllocas, ByValArguments, BasePointer, StackGuardSlot); // Safe stack object that stores the current unsafe stack top. It is updated // as unsafe dynamic (non-constant-sized) allocas are allocated and freed. @@ -838,7 +838,7 @@ DynamicAllocas); // Restore the unsafe stack pointer before each return. - for (ReturnInst *RI : Returns) { + for (Instruction *RI : Returns) { IRB.SetInsertPoint(RI); IRB.CreateStore(BasePointer, UnsafeStackPtr); } diff --git a/llvm/test/Transforms/SafeStack/X86/musttail.ll b/llvm/test/Transforms/SafeStack/X86/musttail.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/SafeStack/X86/musttail.ll @@ -0,0 +1,46 @@ +; To test that safestack does not break the musttail call contract. +; +; RUN: opt < %s --safe-stack -S | FileCheck %s + +target triple = "x86_64-unknown-linux-gnu" + +declare i32 @foo(i32* %p) +declare void @alloca_test_use([10 x i8]*) + +define i32 @call_foo(i32* %a) safestack { +; CHECK-LABEL: @call_foo( +; CHECK-NEXT: [[UNSAFE_STACK_PTR:%.*]] = load i8*, i8** @__safestack_unsafe_stack_ptr, align 8 +; CHECK-NEXT: [[UNSAFE_STACK_STATIC_TOP:%.*]] = getelementptr i8, i8* [[UNSAFE_STACK_PTR]], i32 -16 +; CHECK-NEXT: store i8* [[UNSAFE_STACK_STATIC_TOP]], i8** @__safestack_unsafe_stack_ptr, align 8 +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i8, i8* [[UNSAFE_STACK_PTR]], i32 -10 +; CHECK-NEXT: [[X_UNSAFE:%.*]] = bitcast i8* [[TMP1]] to [10 x i8]* +; CHECK-NEXT: call void @alloca_test_use([10 x i8]* [[X_UNSAFE]]) +; CHECK-NEXT: store i8* [[UNSAFE_STACK_PTR]], i8** @__safestack_unsafe_stack_ptr, align 8 +; CHECK-NEXT: [[R:%.*]] = musttail call i32 @foo(i32* [[A:%.*]]) +; CHECK-NEXT: ret i32 [[R]] +; + %x = alloca [10 x i8], align 1 + call void @alloca_test_use([10 x i8]* %x) + %r = musttail call i32 @foo(i32* %a) + ret i32 %r +} + +define i32 @call_foo_cast(i32* %a) safestack { +; CHECK-LABEL: @call_foo_cast( +; CHECK-NEXT: [[UNSAFE_STACK_PTR:%.*]] = load i8*, i8** @__safestack_unsafe_stack_ptr, align 8 +; CHECK-NEXT: [[UNSAFE_STACK_STATIC_TOP:%.*]] = getelementptr i8, i8* [[UNSAFE_STACK_PTR]], i32 -16 +; CHECK-NEXT: store i8* [[UNSAFE_STACK_STATIC_TOP]], i8** @__safestack_unsafe_stack_ptr, align 8 +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i8, i8* [[UNSAFE_STACK_PTR]], i32 -10 +; CHECK-NEXT: [[X_UNSAFE:%.*]] = bitcast i8* [[TMP1]] to [10 x i8]* +; CHECK-NEXT: call void @alloca_test_use([10 x i8]* [[X_UNSAFE]]) +; CHECK-NEXT: store i8* [[UNSAFE_STACK_PTR]], i8** @__safestack_unsafe_stack_ptr, align 8 +; CHECK-NEXT: [[R:%.*]] = musttail call i32 @foo(i32* [[A:%.*]]) +; CHECK-NEXT: [[T:%.*]] = bitcast i32 [[R]] to i32 +; CHECK-NEXT: ret i32 [[T]] +; + %x = alloca [10 x i8], align 1 + call void @alloca_test_use([10 x i8]* %x) + %r = musttail call i32 @foo(i32* %a) + %t = bitcast i32 %r to i32 + ret i32 %t +}