diff --git a/llvm/include/llvm/CodeGen/MachineBasicBlock.h b/llvm/include/llvm/CodeGen/MachineBasicBlock.h --- a/llvm/include/llvm/CodeGen/MachineBasicBlock.h +++ b/llvm/include/llvm/CodeGen/MachineBasicBlock.h @@ -153,12 +153,18 @@ /// LLVM IR. bool IsEHScopeEntry = false; + /// Indicates if this is a target block of a catchret. + bool IsEHCatchRetTarget = false; + /// Indicate that this basic block is the entry block of an EH funclet. bool IsEHFuncletEntry = false; /// Indicate that this basic block is the entry block of a cleanup funclet. bool IsCleanupFuncletEntry = false; + /// Indicates that this block is an EHCont target. + bool IsEHContTarget = false; + /// With basic block sections, this stores the Section ID of the basic block. MBBSectionID SectionID{0}; @@ -175,6 +181,9 @@ /// is only computed once and is cached. mutable MCSymbol *CachedMCSymbol = nullptr; + /// Cached EHCont MCSymbol for this block (created by getEHContSymbol()). + mutable MCSymbol *CachedEHContMCSymbol = nullptr; + /// Marks the end of the basic block. Used during basic block sections to /// calculate the size of the basic block, or the BB section ending with it. mutable MCSymbol *CachedEndMCSymbol = nullptr; @@ -445,6 +454,12 @@ /// that used to have a catchpad or cleanuppad instruction in the LLVM IR. void setIsEHScopeEntry(bool V = true) { IsEHScopeEntry = V; } + /// Returns true if this is a target block of a catchret. + bool isEHCatchRetTarget() const { return IsEHCatchRetTarget; } + + /// Indicates if this is a target block of a catchret. + void setIsEHCatchRetTarget(bool V = true) { IsEHCatchRetTarget = V; } + /// Returns true if this is the entry block of an EH funclet. bool isEHFuncletEntry() const { return IsEHFuncletEntry; } @@ -457,6 +472,12 @@ /// Indicates if this is the entry block of a cleanup funclet. void setIsCleanupFuncletEntry(bool V = true) { IsCleanupFuncletEntry = V; } + /// Returns true if this block is an EHCont target. 
+ bool isEHContTarget() const { return IsEHContTarget; } + + /// Indicates if this is an EHCont target. + void setIsEHContTarget(bool V = true) { IsEHContTarget = V; } + /// Returns true if this block begins any section. bool isBeginSection() const { return IsBeginSection; } @@ -910,6 +931,9 @@ /// Return the MCSymbol for this basic block. MCSymbol *getSymbol() const; + /// Return the EHCont symbol for this basic block, creating it on first use. + MCSymbol *getEHContSymbol() const; + Optional getIrrLoopHeaderWeight() const { return IrrLoopHeaderWeight; } diff --git a/llvm/include/llvm/CodeGen/MachineFunction.h b/llvm/include/llvm/CodeGen/MachineFunction.h --- a/llvm/include/llvm/CodeGen/MachineFunction.h +++ b/llvm/include/llvm/CodeGen/MachineFunction.h @@ -321,6 +321,10 @@ /// construct a table of valid longjmp targets for Windows Control Flow Guard. std::vector LongjmpTargets; + /// Symbols of the basic blocks that are the target of catchrets. Used to + /// construct a table of valid targets for Windows EHCont Guard. + std::vector CatchretTargets; + /// \name Exception Handling /// \{ @@ -341,6 +345,7 @@ bool CallsEHReturn = false; bool CallsUnwindInit = false; + bool HasEHCatchRet = false; bool HasEHScopes = false; bool HasEHFunclets = false; @@ -930,6 +935,18 @@ /// Control Flow Guard. void addLongjmpTarget(MCSymbol *Target) { LongjmpTargets.push_back(Target); } + /// Returns a reference to the list of catchret target symbols. + /// Used to build the catchret target table for Windows EHCont Guard. + const std::vector &getCatchretTargets() const { + return CatchretTargets; + } + + /// Add the specified symbol to the list of valid catchret targets for Windows + /// EHCont Guard. 
+ void addCatchretTarget(MCSymbol *Target) { + CatchretTargets.push_back(Target); + } + /// \name Exception Handling /// \{ @@ -939,6 +956,9 @@ bool callsUnwindInit() const { return CallsUnwindInit; } void setCallsUnwindInit(bool b) { CallsUnwindInit = b; } + bool hasEHCatchRet() const { return HasEHCatchRet; } + void setHasEHCatchRet(bool V) { HasEHCatchRet = V; } + bool hasEHScopes() const { return HasEHScopes; } void setHasEHScopes(bool V) { HasEHScopes = V; } diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h --- a/llvm/include/llvm/CodeGen/Passes.h +++ b/llvm/include/llvm/CodeGen/Passes.h @@ -464,6 +464,10 @@ /// \see CFGuardLongjmp.cpp FunctionPass *createCFGuardLongjmpPass(); + /// Creates EHContGuard catchret target identification pass. + /// \see EHContGuardCatchret.cpp + FunctionPass *createEHContGuardCatchretPass(); + /// Create Hardware Loop pass. \see HardwareLoops.cpp FunctionPass *createHardwareLoopsPass(); diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h --- a/llvm/include/llvm/InitializePasses.h +++ b/llvm/include/llvm/InitializePasses.h @@ -148,6 +148,7 @@ void initializeEarlyMachineLICMPass(PassRegistry&); void initializeEarlyTailDuplicatePass(PassRegistry&); void initializeEdgeBundlesPass(PassRegistry&); +void initializeEHContGuardCatchretPass(PassRegistry &); void initializeEliminateAvailableExternallyLegacyPassPass(PassRegistry&); void initializeEntryExitInstrumenterPass(PassRegistry&); void initializeExpandMemCmpPassPass(PassRegistry&); diff --git a/llvm/include/llvm/MC/MCObjectFileInfo.h b/llvm/include/llvm/MC/MCObjectFileInfo.h --- a/llvm/include/llvm/MC/MCObjectFileInfo.h +++ b/llvm/include/llvm/MC/MCObjectFileInfo.h @@ -218,6 +218,7 @@ MCSection *PDataSection = nullptr; MCSection *XDataSection = nullptr; MCSection *SXDataSection = nullptr; + MCSection *GEHContSection = nullptr; MCSection *GFIDsSection = nullptr; MCSection *GIATsSection = nullptr; MCSection 
*GLJMPSection = nullptr; @@ -405,6 +406,7 @@ MCSection *getPDataSection() const { return PDataSection; } MCSection *getXDataSection() const { return XDataSection; } MCSection *getSXDataSection() const { return SXDataSection; } + MCSection *getGEHContSection() const { return GEHContSection; } MCSection *getGFIDsSection() const { return GFIDsSection; } MCSection *getGIATsSection() const { return GIATsSection; } MCSection *getGLJMPSection() const { return GLJMPSection; } diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp --- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp @@ -3201,6 +3201,9 @@ } } + if (MBB.isEHContTarget()) + OutStreamer->emitLabel(MBB.getEHContSymbol()); + // With BB sections, each basic block must handle CFI information on its own // if it begins a section (Entry block is handled separately by // AsmPrinterHandler::beginFunction). diff --git a/llvm/lib/CodeGen/AsmPrinter/WinException.h b/llvm/lib/CodeGen/AsmPrinter/WinException.h --- a/llvm/lib/CodeGen/AsmPrinter/WinException.h +++ b/llvm/lib/CodeGen/AsmPrinter/WinException.h @@ -44,6 +44,9 @@ /// The section of the last funclet start. MCSection *CurrentFuncletTextSection = nullptr; + /// The catchret target symbols to emit in the ehcont (.gehcont$y) section. + std::vector EHContTargets; + void emitCSpecificHandlerTable(const MachineFunction *MF); void emitSEHActionsForRange(const WinEHFuncInfo &FuncInfo, diff --git a/llvm/lib/CodeGen/AsmPrinter/WinException.cpp b/llvm/lib/CodeGen/AsmPrinter/WinException.cpp --- a/llvm/lib/CodeGen/AsmPrinter/WinException.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/WinException.cpp @@ -55,6 +55,13 @@ for (const Function &F : *M) if (F.hasFnAttribute("safeseh")) OS.EmitCOFFSafeSEH(Asm->getSymbol(&F)); + + if (M->getModuleFlag("ehcontguard") && !EHContTargets.empty()) { + // Emit the symbol index of each ehcont target. 
+ OS.SwitchSection(Asm->OutContext.getObjectFileInfo()->getGEHContSection()); + for (const MCSymbol *S : EHContTargets) + OS.EmitCOFFSymbolIndex(S); + } } void WinException::beginFunction(const MachineFunction *MF) { @@ -164,6 +171,12 @@ Asm->OutStreamer->PopSection(); } + + if (!MF->getCatchretTargets().empty()) { + // Copy the function's catchret targets to a module-level list. + EHContTargets.insert(EHContTargets.end(), MF->getCatchretTargets().begin(), + MF->getCatchretTargets().end()); + } } /// Retrieve the MCSymbol for a GlobalValue or MachineBasicBlock. diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt --- a/llvm/lib/CodeGen/CMakeLists.txt +++ b/llvm/lib/CodeGen/CMakeLists.txt @@ -24,6 +24,7 @@ DwarfEHPrepare.cpp EarlyIfConversion.cpp EdgeBundles.cpp + EHContGuardCatchret.cpp ExecutionDomainFix.cpp ExpandMemCmp.cpp ExpandPostRAPseudos.cpp diff --git a/llvm/lib/CodeGen/EHContGuardCatchret.cpp b/llvm/lib/CodeGen/EHContGuardCatchret.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/CodeGen/EHContGuardCatchret.cpp @@ -0,0 +1,85 @@ +//===-- EHContGuardCatchret.cpp - Catchret target symbols -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains a machine function pass to mark each valid catchret +/// target so that a symbol is emitted at its start, and store that symbol in +/// the MachineFunction's CatchretTargets vector. This will be used to emit the +/// table of valid targets used by EHCont Guard. 
+/// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/Statistic.h" +#include "llvm/CodeGen/MachineBasicBlock.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineInstr.h" +#include "llvm/CodeGen/MachineModuleInfo.h" +#include "llvm/CodeGen/MachineOperand.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/InitializePasses.h" + +using namespace llvm; + +#define DEBUG_TYPE "ehcontguard-catchret" + +STATISTIC(EHContGuardCatchretTargets, + "Number of EHCont Guard catchret targets"); + +namespace { + +/// MachineFunction pass to mark each valid catchret target for EHCont Guard +/// and store its symbol in the MachineFunction's CatchretTargets vector. +class EHContGuardCatchret : public MachineFunctionPass { +public: + static char ID; + + EHContGuardCatchret() : MachineFunctionPass(ID) { + initializeEHContGuardCatchretPass(*PassRegistry::getPassRegistry()); + } + + StringRef getPassName() const override { + return "EH Cont Guard catchret targets"; + } + + bool runOnMachineFunction(MachineFunction &MF) override; +}; + +} // end anonymous namespace + +char EHContGuardCatchret::ID = 0; + +INITIALIZE_PASS(EHContGuardCatchret, "EHContGuardCatchret", + "Insert symbols at valid catchret targets for /guard:ehcont", + false, false) +FunctionPass *llvm::createEHContGuardCatchretPass() { + return new EHContGuardCatchret(); +} + +bool EHContGuardCatchret::runOnMachineFunction(MachineFunction &MF) { + + // Skip modules for which the ehcontguard flag is not set. 
+ if (!MF.getMMI().getModule()->getModuleFlag("ehcontguard")) + return false; + + // Skip functions that do not contain a catchret. + if (!MF.hasEHCatchRet()) + return false; + + bool Result = false; + + for (MachineBasicBlock &MBB : MF) { + if (MBB.isEHCatchRetTarget() && !MBB.isEHFuncletEntry()) { + MBB.setIsEHContTarget(); + MF.addCatchretTarget(MBB.getEHContSymbol()); + EHContGuardCatchretTargets++; + Result = true; + } + } + + return Result; +} diff --git a/llvm/lib/CodeGen/MachineBasicBlock.cpp b/llvm/lib/CodeGen/MachineBasicBlock.cpp --- a/llvm/lib/CodeGen/MachineBasicBlock.cpp +++ b/llvm/lib/CodeGen/MachineBasicBlock.cpp @@ -87,6 +87,17 @@ return CachedMCSymbol; } +MCSymbol *MachineBasicBlock::getEHContSymbol() const { + if (!CachedEHContMCSymbol) { + const MachineFunction *MF = getParent(); + SmallString<128> SymbolName; + raw_svector_ostream(SymbolName) + << "$ehgcr_" << MF->getFunctionNumber() << '_' << getNumber(); + CachedEHContMCSymbol = MF->getContext().getOrCreateSymbol(SymbolName); + } + return CachedEHContMCSymbol; +} + MCSymbol *MachineBasicBlock::getEndSymbol() const { if (!CachedEndMCSymbol) { const MachineFunction *MF = getParent(); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -1592,6 +1592,8 @@ // Update machine-CFG edge. 
MachineBasicBlock *TargetMBB = FuncInfo.MBBMap[I.getSuccessor()]; FuncInfo.MBB->addSuccessor(TargetMBB); + TargetMBB->setIsEHCatchRetTarget(true); + DAG.getMachineFunction().setHasEHCatchRet(true); auto Pers = classifyEHPersonality(FuncInfo.Fn->getPersonalityFn()); bool IsSEH = isAsynchronousEHPersonality(Pers); diff --git a/llvm/lib/MC/MCObjectFileInfo.cpp b/llvm/lib/MC/MCObjectFileInfo.cpp --- a/llvm/lib/MC/MCObjectFileInfo.cpp +++ b/llvm/lib/MC/MCObjectFileInfo.cpp @@ -753,6 +753,11 @@ SXDataSection = Ctx->getCOFFSection(".sxdata", COFF::IMAGE_SCN_LNK_INFO, SectionKind::getMetadata()); + GEHContSection = Ctx->getCOFFSection(".gehcont$y", + COFF::IMAGE_SCN_CNT_INITIALIZED_DATA | + COFF::IMAGE_SCN_MEM_READ, + SectionKind::getMetadata()); + GFIDsSection = Ctx->getCOFFSection(".gfids$y", COFF::IMAGE_SCN_CNT_INITIALIZED_DATA | COFF::IMAGE_SCN_MEM_READ, diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp --- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp +++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp @@ -191,7 +191,30 @@ } // end anonymous namespace void AArch64AsmPrinter::emitStartOfAsmFile(Module &M) { - if (!TM.getTargetTriple().isOSBinFormatELF()) + const Triple &TT = TM.getTargetTriple(); + + if (TT.isOSBinFormatCOFF()) { + // Emit an absolute @feat.00 symbol. This appears to be some kind of + // compiler features bitfield read by link.exe. + MCSymbol *S = MMI->getContext().getOrCreateSymbol(StringRef("@feat.00")); + OutStreamer->BeginCOFFSymbolDef(S); + OutStreamer->EmitCOFFSymbolStorageClass(COFF::IMAGE_SYM_CLASS_STATIC); + OutStreamer->EmitCOFFSymbolType(COFF::IMAGE_SYM_DTYPE_NULL); + OutStreamer->EndCOFFSymbolDef(); + int64_t Feat00Flags = 0; + + if (M.getModuleFlag("cfguard")) + Feat00Flags |= 0x800; // Object is CFG-aware. + + if (M.getModuleFlag("ehcontguard")) + Feat00Flags |= 0x4000; // Object also has EHCont. 
+ + OutStreamer->emitSymbolAttribute(S, MCSA_Global); + OutStreamer->emitAssignment( + S, MCConstantExpr::create(Feat00Flags, MMI->getContext())); + } + + if (!TT.isOSBinFormatELF()) return; // Assemble feature flags that may require creation of a note section. diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp --- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp @@ -676,9 +676,12 @@ if (BranchRelaxation) addPass(&BranchRelaxationPassID); - // Identify valid longjmp targets for Windows Control Flow Guard. - if (TM->getTargetTriple().isOSWindows()) + if (TM->getTargetTriple().isOSWindows()) { + // Identify valid longjmp targets for Windows Control Flow Guard. addPass(createCFGuardLongjmpPass()); + // Identify valid eh continuation targets for Windows EHCont Guard. + addPass(createEHContGuardCatchretPass()); + } if (TM->getOptLevel() != CodeGenOpt::None && EnableCompressJumpTables) addPass(createAArch64CompressJumpTablesPass()); diff --git a/llvm/lib/Target/ARM/ARMTargetMachine.cpp b/llvm/lib/Target/ARM/ARMTargetMachine.cpp --- a/llvm/lib/Target/ARM/ARMTargetMachine.cpp +++ b/llvm/lib/Target/ARM/ARMTargetMachine.cpp @@ -564,7 +564,10 @@ addPass(createARMConstantIslandPass()); addPass(createARMLowOverheadLoopsPass()); - // Identify valid longjmp targets for Windows Control Flow Guard. - if (TM->getTargetTriple().isOSWindows()) + if (TM->getTargetTriple().isOSWindows()) { + // Identify valid longjmp targets for Windows Control Flow Guard. addPass(createCFGuardLongjmpPass()); + // Identify valid eh continuation targets for Windows EHCont Guard. 
+ addPass(createEHContGuardCatchretPass()); + } } diff --git a/llvm/lib/Target/X86/X86AsmPrinter.cpp b/llvm/lib/Target/X86/X86AsmPrinter.cpp --- a/llvm/lib/Target/X86/X86AsmPrinter.cpp +++ b/llvm/lib/Target/X86/X86AsmPrinter.cpp @@ -683,8 +683,11 @@ Feat00Flags |= 1; } if (M.getModuleFlag("cfguard")) Feat00Flags |= 0x800; // Object is CFG-aware. + + if (M.getModuleFlag("ehcontguard")) + Feat00Flags |= 0x4000; // Object also has EHCont. OutStreamer->emitSymbolAttribute(S, MCSA_Global); OutStreamer->emitAssignment( diff --git a/llvm/lib/Target/X86/X86TargetMachine.cpp b/llvm/lib/Target/X86/X86TargetMachine.cpp --- a/llvm/lib/Target/X86/X86TargetMachine.cpp +++ b/llvm/lib/Target/X86/X86TargetMachine.cpp @@ -568,9 +568,13 @@ (!TT.isOSWindows() || MAI->getExceptionHandlingType() == ExceptionHandling::DwarfCFI)) addPass(createCFIInstrInserter()); - // Identify valid longjmp targets for Windows Control Flow Guard. - if (TT.isOSWindows()) + + if (TT.isOSWindows()) { + // Identify valid longjmp targets for Windows Control Flow Guard. addPass(createCFGuardLongjmpPass()); + // Identify valid eh continuation targets for Windows EHCont Guard. 
+ addPass(createEHContGuardCatchretPass()); + } addPass(createX86LoadValueInjectionRetHardeningPass()); } diff --git a/llvm/test/CodeGen/AArch64/ehcontguard.ll b/llvm/test/CodeGen/AArch64/ehcontguard.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/ehcontguard.ll @@ -0,0 +1,31 @@ +; RUN: llc < %s -mtriple=aarch64-windows | FileCheck %s +; EHCont Guard is currently only available on Windows + +; CHECK: .set @feat.00, 16384 + +; CHECK: [[EHCONT:\$ehgcr_[0-9]+_[0-9]+]]: +; CHECK: .section .gehcont$y +; CHECK-NEXT: .symidx [[EHCONT]] + +define dso_local void @"?func1@@YAXXZ"() #0 personality i8* bitcast (i32 (...)* @__CxxFrameHandler3 to i8*) { +entry: + invoke void @"?func2@@YAXXZ"() + to label %invoke.cont unwind label %catch.dispatch +catch.dispatch: ; preds = %entry + %0 = catchswitch within none [label %catch] unwind to caller +catch: ; preds = %catch.dispatch + %1 = catchpad within %0 [i8* null, i32 64, i8* null] + catchret from %1 to label %catchret.dest +catchret.dest: ; preds = %catch + br label %try.cont +try.cont: ; preds = %catchret.dest, %invoke.cont + ret void +invoke.cont: ; preds = %entry + br label %try.cont +} + +declare dso_local void @"?func2@@YAXXZ"() #1 +declare dso_local i32 @__CxxFrameHandler3(...) 
+ +!llvm.module.flags = !{!0} +!0 = !{i32 1, !"ehcontguard", i32 1} diff --git a/llvm/test/CodeGen/X86/ehcontguard.ll b/llvm/test/CodeGen/X86/ehcontguard.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/X86/ehcontguard.ll @@ -0,0 +1,31 @@ +; RUN: llc < %s -mtriple=x86_64-pc-windows-msvc | FileCheck %s +; EHCont Guard is currently only available on Windows + +; CHECK: .set @feat.00, 16384 + +; CHECK: [[EHCONT:\$ehgcr_[0-9]+_[0-9]+]]: +; CHECK: .section .gehcont$y +; CHECK-NEXT: .symidx [[EHCONT]] + +define dso_local void @"?func1@@YAXXZ"() #0 personality i8* bitcast (i32 (...)* @__CxxFrameHandler3 to i8*) { +entry: + invoke void @"?func2@@YAXXZ"() + to label %invoke.cont unwind label %catch.dispatch +catch.dispatch: ; preds = %entry + %0 = catchswitch within none [label %catch] unwind to caller +catch: ; preds = %catch.dispatch + %1 = catchpad within %0 [i8* null, i32 64, i8* null] + catchret from %1 to label %catchret.dest +catchret.dest: ; preds = %catch + br label %try.cont +try.cont: ; preds = %catchret.dest, %invoke.cont + ret void +invoke.cont: ; preds = %entry + br label %try.cont +} + +declare dso_local void @"?func2@@YAXXZ"() #1 +declare dso_local i32 @__CxxFrameHandler3(...) + +!llvm.module.flags = !{!0} +!0 = !{i32 1, !"ehcontguard", i32 1} diff --git a/llvm/utils/gn/secondary/llvm/lib/CodeGen/BUILD.gn b/llvm/utils/gn/secondary/llvm/lib/CodeGen/BUILD.gn --- a/llvm/utils/gn/secondary/llvm/lib/CodeGen/BUILD.gn +++ b/llvm/utils/gn/secondary/llvm/lib/CodeGen/BUILD.gn @@ -42,6 +42,7 @@ "DwarfEHPrepare.cpp", "EarlyIfConversion.cpp", "EdgeBundles.cpp", + "EHContGuardCatchret.cpp", "ExecutionDomainFix.cpp", "ExpandMemCmp.cpp", "ExpandPostRAPseudos.cpp",