Index: lib/Target/AArch64/AArch64.td
===================================================================
--- lib/Target/AArch64/AArch64.td
+++ lib/Target/AArch64/AArch64.td
@@ -131,6 +131,9 @@
                                            "Reserve X"#i#", making it unavailable "
                                            "as a GPR">;
 
+def FeatureProtectX18 : SubtargetFeature<"protect-x18", "ProtectX18", "true",
+                                         "Guard against called code changing x18">;
+
 foreach i = {8-15,18} in
     def FeatureCallSavedX#i : SubtargetFeature<"call-saved-x"#i,
          "CustomCallSavedXRegs["#i#"]", "true", "Make X"#i#" callee saved.">;
Index: lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- lib/Target/AArch64/AArch64ISelLowering.cpp
+++ lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3228,6 +3228,17 @@
     }
   }
 
+  if (Subtarget->shouldProtectX18()) {
+    // If protecting X18, make a copy of the value of X18 on entry.
+    MVT PtrTy = getPointerTy(DAG.getDataLayout());
+    unsigned Reg = MF.getRegInfo().createVirtualRegister(getRegClassFor(PtrTy));
+    FuncInfo->setX18BackupReg(Reg);
+
+    SDValue X18Reg = DAG.getRegister(AArch64::X18, MVT::i64);
+    SDValue Copy = DAG.getCopyToReg(DAG.getEntryNode(), DL, Reg, X18Reg);
+    Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Copy, Chain);
+  }
+
   unsigned StackArgSize = CCInfo.getNextStackOffset();
   bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt;
   if (DoesCalleeRestoreStack(CallConv, TailCallOpt)) {
@@ -3418,6 +3429,17 @@
   CallingConv::ID CallerCC = CallerF.getCallingConv();
   bool CCMatch = CallerCC == CalleeCC;
 
+  if (Subtarget->shouldProtectX18()) {
+    // If protecting X18, we can't do a tail call to an unknown function
+    // (functions from the same translation unit are deemed safe), as we'd
+    // need to restore X18 after the function.
+    bool SameTU = false;
+    if (auto *G = dyn_cast<GlobalAddressSDNode>(Callee))
+      SameTU = !G->getGlobal()->isDeclarationForLinker();
+    if (!SameTU)
+      return false;
+  }
+
   // Byval parameters hand the function a pointer directly into the stack area
   // we want to reuse during a tail call. Working around this *is* possible (see
   // X86) but less efficient and uglier in LowerCall.
@@ -3830,8 +3852,10 @@
   // If the callee is a GlobalAddress/ExternalSymbol node (quite common, every
   // direct call is) turn it into a TargetGlobalAddress/TargetExternalSymbol
   // node so that legalize doesn't hack it.
+  bool SameTU = false;
   if (auto *G = dyn_cast<GlobalAddressSDNode>(Callee)) {
     auto GV = G->getGlobal();
+    SameTU = !GV->isDeclarationForLinker();
     if (Subtarget->classifyGlobalFunctionReference(GV, getTargetMachine()) ==
         AArch64II::MO_GOT) {
       Callee = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, AArch64II::MO_GOT);
@@ -3932,9 +3956,18 @@
 
   // Handle result values, copying them out of physregs into vregs that we
   // return.
-  return LowerCallResult(Chain, InFlag, CallConv, IsVarArg, Ins, DL, DAG,
-                         InVals, IsThisReturn,
-                         IsThisReturn ? OutVals[0] : SDValue());
+  Chain =
+      LowerCallResult(Chain, InFlag, CallConv, IsVarArg, Ins, DL, DAG, InVals,
+                      IsThisReturn, IsThisReturn ? OutVals[0] : SDValue());
+
+  // If protecting X18, restore X18 after a call to an unknown function
+  // (functions from the same translation unit are deemed safe).
+  if (Subtarget->shouldProtectX18() && !SameTU)
+    Chain = DAG.getCopyToReg(
+        Chain, DL, AArch64::X18,
+        DAG.getCopyFromReg(Chain, DL, FuncInfo->getX18BackupReg(), PtrVT));
+
+  return Chain;
 }
 
 bool AArch64TargetLowering::CanLowerReturn(
Index: lib/Target/AArch64/AArch64MachineFunctionInfo.h
===================================================================
--- lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -96,6 +96,8 @@
   /// which the sret argument is passed.
   unsigned SRetReturnReg = 0;
 
+  unsigned X18BackupReg = 0;
+
   /// Has a value when it is known whether or not the function uses a
   /// redzone, and no value otherwise.
   /// Initialized during frame lowering, unless the function has the noredzone
@@ -173,6 +175,9 @@
   unsigned getSRetReturnReg() const { return SRetReturnReg; }
   void setSRetReturnReg(unsigned Reg) { SRetReturnReg = Reg; }
 
+  unsigned getX18BackupReg() const { return X18BackupReg; }
+  void setX18BackupReg(unsigned Reg) { X18BackupReg = Reg; }
+
   unsigned getJumpTableEntrySize(int Idx) const {
     auto It = JumpTableEntryInfo.find(Idx);
     if (It != JumpTableEntryInfo.end())
Index: lib/Target/AArch64/AArch64Subtarget.h
===================================================================
--- lib/Target/AArch64/AArch64Subtarget.h
+++ lib/Target/AArch64/AArch64Subtarget.h
@@ -190,6 +190,9 @@
   // ReserveXRegister[i] - X#i is not available as a general purpose register.
   BitVector ReserveXRegister;
 
+  // ProtectX18 - try to guard against foreign code changing X18.
+  bool ProtectX18 = false;
+
   // CustomCallUsedXRegister[i] - X#i call saved.
   BitVector CustomCallSavedXRegs;
 
@@ -283,6 +286,7 @@
 
   bool isXRegisterReserved(size_t i) const { return ReserveXRegister[i]; }
   unsigned getNumXRegisterReserved() const { return ReserveXRegister.count(); }
+  bool shouldProtectX18() const { return ProtectX18; }
   bool isXRegCustomCalleeSaved(size_t i) const {
     return CustomCallSavedXRegs[i];
   }
Index: test/CodeGen/AArch64/protect-x18.ll
===================================================================
--- /dev/null
+++ test/CodeGen/AArch64/protect-x18.ll
@@ -0,0 +1,68 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-linux-gnu | FileCheck %s
+
+define dso_local void @localfunc() #0 {
+; CHECK-LABEL: localfunc:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    //APP
+; CHECK-NEXT:    //NO_APP
+; CHECK-NEXT:    ret
+entry:
+  tail call void asm sideeffect "", ""() #1
+  ret void
+}
+
+define dso_local i32 @func(i32 ()* nocapture %ptr) #0 {
+; CHECK-LABEL: func:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    str x20, [sp, #-32]! // 8-byte Folded Spill
+; CHECK-NEXT:    stp x19, x30, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    mov x19, x0
+; CHECK-NEXT:    mov x20, x18
+; CHECK-NEXT:    bl localfunc
+; CHECK-NEXT:    blr x19
+; CHECK-NEXT:    mov w19, w0
+; CHECK-NEXT:    mov x18, x20
+; CHECK-NEXT:    bl other
+; CHECK-NEXT:    mov x18, x20
+; CHECK-NEXT:    bl localfunc
+; CHECK-NEXT:    mov w0, w19
+; CHECK-NEXT:    ldp x19, x30, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x20, [sp], #32 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  tail call void @localfunc()
+  %call = tail call i32 %ptr()
+  tail call void @other()
+  tail call void @localfunc()
+  ret i32 %call
+}
+
+declare dso_local void @other()
+
+define dso_local void @tailcall_other() #0 {
+; CHECK-LABEL: tailcall_other:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    stp x19, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x19, x18
+; CHECK-NEXT:    bl other
+; CHECK-NEXT:    mov x18, x19
+; CHECK-NEXT:    ldp x19, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  tail call void @other()
+  ret void
+}
+
+define dso_local void @tailcall_local() #0 {
+; CHECK-LABEL: tailcall_local:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    b localfunc
+entry:
+  tail call void @localfunc()
+  ret void
+}
+
+attributes #0 = { nounwind "target-features"="+reserve-x18,+protect-x18" }
+
+attributes #1 = { nounwind }
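
Usage note (not part of the patch itself): the test enables the feature per-function through the "target-features" attribute. Assuming the usual llc feature-flag syntax, the same behaviour could also be requested for a whole module from the command line, roughly:

  llc -mtriple=aarch64-linux-gnu -mattr=+reserve-x18,+protect-x18 test/CodeGen/AArch64/protect-x18.ll -o -

Here +reserve-x18 (an existing feature) keeps the compiler's own code generation away from X18, while the new +protect-x18 feature copies X18 into a backup register on function entry and restores it after calls to functions outside the translation unit, as exercised by the checks above.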