diff --git a/llvm/lib/Target/ARM/ARMMachineFunctionInfo.h b/llvm/lib/Target/ARM/ARMMachineFunctionInfo.h
--- a/llvm/lib/Target/ARM/ARMMachineFunctionInfo.h
+++ b/llvm/lib/Target/ARM/ARMMachineFunctionInfo.h
@@ -142,6 +142,17 @@
   /// con/destructors).
   bool PreservesR0 = false;
 
+  /// True if the function should sign its return address.
+  bool SignReturnAddress = false;
+
+  /// True if the function should sign its return address, even if LR is not
+  /// saved.
+  bool SignReturnAddressAll = false;
+
+  /// True if BTI instructions should be placed at potential indirect jump
+  /// destinations.
+  bool BranchTargetEnforcement = false;
+
 public:
   ARMFunctionInfo() = default;
 
@@ -268,6 +279,28 @@
 
   void setPreservesR0() { PreservesR0 = true; }
   bool getPreservesR0() const { return PreservesR0; }
+
+  /// Returns true if this function should sign its return address, taking
+  /// into account whether it actually spills LR (LRSpilled).
+  bool shouldSignReturnAddress() const {
+    return shouldSignReturnAddress(LRSpilled);
+  }
+
+  /// Returns true if this function should sign its return address, where
+  /// \p SpillsLR says whether the function saves LR. When SignReturnAddressAll
+  /// is set, the return address is signed even if LR is not spilled.
+  bool shouldSignReturnAddress(bool SpillsLR) const {
+    if (!SignReturnAddress)
+      return false;
+    if (SignReturnAddressAll)
+      return true;
+    // Otherwise sign only when LR is saved, as reported by the caller.
+    return SpillsLR;
+  }
+
+  /// True if BTI instructions should be placed at potential indirect jump
+  /// destinations.
+  bool branchTargetEnforcement() const { return BranchTargetEnforcement; }
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Target/ARM/ARMMachineFunctionInfo.cpp b/llvm/lib/Target/ARM/ARMMachineFunctionInfo.cpp
--- a/llvm/lib/Target/ARM/ARMMachineFunctionInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMMachineFunctionInfo.cpp
@@ -13,8 +13,72 @@
 
 void ARMFunctionInfo::anchor() {}
 
+// Computes the ARMFunctionInfo member BranchTargetEnforcement. BTI is only
+// enabled for M-profile subtargets with v7 ops. The function attribute
+// "branch-target-enforcement" takes precedence over the module flag of the
+// same name; if neither is present, BTI is disabled.
+static bool GetBranchTargetEnforcement(MachineFunction &MF) {
+  const auto &Subtarget = MF.getSubtarget<ARMSubtarget>();
+  if (!Subtarget.isMClass() || !Subtarget.hasV7Ops())
+    return false;
+
+  const Function &F = MF.getFunction();
+  if (!F.hasFnAttribute("branch-target-enforcement")) {
+    if (const auto *BTE = mdconst::extract_or_null<ConstantInt>(
+            F.getParent()->getModuleFlag("branch-target-enforcement")))
+      return BTE->getZExtValue();
+    return false;
+  }
+
+  const StringRef BTIEnable =
+      F.getFnAttribute("branch-target-enforcement").getValueAsString();
+  assert(BTIEnable.equals_insensitive("true") ||
+         BTIEnable.equals_insensitive("false"));
+  return BTIEnable.equals_insensitive("true");
+}
+
+// The pair returns values for the ARMFunctionInfo members
+// SignReturnAddress and SignReturnAddressAll respectively. The function
+// attribute "sign-return-address" ("none", "non-leaf" or "all") takes
+// precedence over the "sign-return-address" and "sign-return-address-all"
+// module flags.
+static std::pair<bool, bool> GetSignReturnAddress(const Function &F) {
+  if (!F.hasFnAttribute("sign-return-address")) {
+    const Module &M = *F.getParent();
+    if (const auto *Sign = mdconst::extract_or_null<ConstantInt>(
+            M.getModuleFlag("sign-return-address"))) {
+      if (Sign->getZExtValue()) {
+        if (const auto *All = mdconst::extract_or_null<ConstantInt>(
+                M.getModuleFlag("sign-return-address-all")))
+          return {true, All->getZExtValue()};
+        return {true, false};
+      }
+    }
+    return {false, false};
+  }
+
+  StringRef Scope = F.getFnAttribute("sign-return-address").getValueAsString();
+  if (Scope.equals("none"))
+    return {false, false};
+
+  if (Scope.equals("all"))
+    return {true, true};
+
+  assert(Scope.equals("non-leaf"));
+  return {true, false};
+}
+
 ARMFunctionInfo::ARMFunctionInfo(MachineFunction &MF)
     : isThumb(MF.getSubtarget<ARMSubtarget>().isThumb()),
       hasThumb2(MF.getSubtarget<ARMSubtarget>().hasThumb2()),
       IsCmseNSEntry(MF.getFunction().hasFnAttribute("cmse_nonsecure_entry")),
-      IsCmseNSCall(MF.getFunction().hasFnAttribute("cmse_nonsecure_call")) {}
+      IsCmseNSCall(MF.getFunction().hasFnAttribute("cmse_nonsecure_call")),
+      BranchTargetEnforcement(GetBranchTargetEnforcement(MF)) {
+
+  // Return address signing is only enabled for M-profile subtargets with v7
+  // ops, mirroring the gate in GetBranchTargetEnforcement.
+  const auto &Subtarget = MF.getSubtarget<ARMSubtarget>();
+  if (Subtarget.isMClass() && Subtarget.hasV7Ops())
+    std::tie(SignReturnAddress, SignReturnAddressAll) =
+        GetSignReturnAddress(MF.getFunction());
+}