diff --git a/llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp b/llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp --- a/llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp +++ b/llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp @@ -42,6 +42,7 @@ #include "X86TargetMachine.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" @@ -104,9 +105,9 @@ cl::init(false), cl::Hidden); static llvm::sys::DynamicLibrary OptimizeDL; -typedef int (*OptimizeCutT)(unsigned int *nodes, unsigned int nodes_size, - unsigned int *edges, int *edge_values, - int *cut_edges /* out */, unsigned int edges_size); +typedef int (*OptimizeCutT)(unsigned int *Nodes, unsigned int NodesSize, + unsigned int *Edges, int *EdgeValues, + int *CutEdges /* out */, unsigned int EdgesSize); static OptimizeCutT OptimizeCut = nullptr; namespace { @@ -148,9 +149,10 @@ private: using GraphBuilder = ImmutableGraphBuilder; + using Edge = MachineGadgetGraph::Edge; + using Node = MachineGadgetGraph::Node; using EdgeSet = MachineGadgetGraph::EdgeSet; using NodeSet = MachineGadgetGraph::NodeSet; - using Gadget = std::pair; const X86Subtarget *STI; const TargetInstrInfo *TII; @@ -162,8 +164,8 @@ const MachineDominanceFrontier &MDF) const; int hardenLoadsWithPlugin(MachineFunction &MF, std::unique_ptr Graph) const; - int hardenLoadsWithGreedyHeuristic( - MachineFunction &MF, std::unique_ptr Graph) const; + int hardenLoadsWithHeuristic(MachineFunction &MF, + std::unique_ptr Graph) const; int elimMitigatedEdgesAndNodes(MachineGadgetGraph &G, EdgeSet &ElimEdges /* in, out */, NodeSet &ElimNodes /* in, out */) const; @@ -198,7 +200,7 @@ using ChildIteratorType = typename Traits::ChildIteratorType; using ChildEdgeIteratorType = typename Traits::ChildEdgeIteratorType; - DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {} + DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {} std::string getNodeLabel(NodeRef Node, GraphType *) { if (Node->getValue() == MachineGadgetGraph::ArgNodeSentinel) @@ -243,7 +245,7 @@ AU.setPreservesCFG(); } -static void WriteGadgetGraph(raw_ostream &OS, MachineFunction &MF, +static void writeGadgetGraph(raw_ostream &OS, MachineFunction &MF, MachineGadgetGraph *G) { WriteGraph(OS, G, /*ShortNames*/ false, "Speculative gadgets for \"" + MF.getName() + "\" function"); @@ -279,7 +281,7 @@ return false; // didn't find any gadgets if (EmitDotVerify) { - WriteGadgetGraph(outs(), MF, Graph.get()); + writeGadgetGraph(outs(), MF, Graph.get()); return false; } @@ -292,7 +294,7 @@ raw_fd_ostream FileOut(FileName, FileError); if (FileError) errs() << FileError.message(); - WriteGadgetGraph(FileOut, MF, Graph.get()); + writeGadgetGraph(FileOut, MF, Graph.get()); FileOut.close(); LLVM_DEBUG(dbgs() << "Emitting gadget graph... Done\n"); if (EmitDotOnly) @@ -313,7 +315,7 @@ } FencesInserted = hardenLoadsWithPlugin(MF, std::move(Graph)); } else { // Use the default greedy heuristic - FencesInserted = hardenLoadsWithGreedyHeuristic(MF, std::move(Graph)); + FencesInserted = hardenLoadsWithHeuristic(MF, std::move(Graph)); } if (FencesInserted > 0) @@ -540,17 +542,17 @@ // Returns the number of remaining gadget edges that could not be eliminated int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes( - MachineGadgetGraph &G, MachineGadgetGraph::EdgeSet &ElimEdges /* in, out */, - MachineGadgetGraph::NodeSet &ElimNodes /* in, out */) const { + MachineGadgetGraph &G, EdgeSet &ElimEdges /* in, out */, + NodeSet &ElimNodes /* in, out */) const { if (G.NumFences > 0) { // Eliminate fences and CFG edges that ingress and egress the fence, as // they are trivially mitigated. - for (const auto &E : G.edges()) { - const MachineGadgetGraph::Node *Dest = E.getDest(); + for (const Edge &E : G.edges()) { + const Node *Dest = E.getDest(); if (isFence(Dest->getValue())) { ElimNodes.insert(*Dest); ElimEdges.insert(E); - for (const auto &DE : Dest->edges()) + for (const Edge &DE : Dest->edges()) ElimEdges.insert(DE); } } @@ -558,29 +560,28 @@ // Find and eliminate gadget edges that have been mitigated. int MitigatedGadgets = 0, RemainingGadgets = 0; - MachineGadgetGraph::NodeSet ReachableNodes{G}; - for (const auto &RootN : G.nodes()) { + NodeSet ReachableNodes{G}; + for (const Node &RootN : G.nodes()) { if (llvm::none_of(RootN.edges(), MachineGadgetGraph::isGadgetEdge)) continue; // skip this node if it isn't a gadget source // Find all of the nodes that are CFG-reachable from RootN using DFS ReachableNodes.clear(); - std::function - FindReachableNodes = - [&](const MachineGadgetGraph::Node *N, bool FirstNode) { - if (!FirstNode) - ReachableNodes.insert(*N); - for (const auto &E : N->edges()) { - const MachineGadgetGraph::Node *Dest = E.getDest(); - if (MachineGadgetGraph::isCFGEdge(E) && - !ElimEdges.contains(E) && !ReachableNodes.contains(*Dest)) - FindReachableNodes(Dest, false); - } - }; + std::function FindReachableNodes = + [&](const Node *N, bool FirstNode) { + if (!FirstNode) + ReachableNodes.insert(*N); + for (const Edge &E : N->edges()) { + const Node *Dest = E.getDest(); + if (MachineGadgetGraph::isCFGEdge(E) && !ElimEdges.contains(E) && + !ReachableNodes.contains(*Dest)) + FindReachableNodes(Dest, false); + } + }; FindReachableNodes(&RootN, true); // Any gadget whose sink is unreachable has been mitigated - for (const auto &E : RootN.edges()) { + for (const Edge &E : RootN.edges()) { if (MachineGadgetGraph::isGadgetEdge(E)) { if (ReachableNodes.contains(*E.getDest())) { // This gadget's sink is reachable @@ -598,8 +599,8 @@ std::unique_ptr X86LoadValueInjectionLoadHardeningPass::trimMitigatedEdges( std::unique_ptr Graph) const { - MachineGadgetGraph::NodeSet ElimNodes{*Graph}; - MachineGadgetGraph::EdgeSet ElimEdges{*Graph}; + NodeSet ElimNodes{*Graph}; + EdgeSet ElimEdges{*Graph}; int RemainingGadgets = elimMitigatedEdgesAndNodes(*Graph, ElimEdges, ElimNodes); if (ElimEdges.empty() && ElimNodes.empty()) { @@ -630,11 +631,11 @@ auto Edges = std::make_unique(Graph->edges_size()); auto EdgeCuts = std::make_unique(Graph->edges_size()); auto EdgeValues = std::make_unique(Graph->edges_size()); - for (const auto &N : Graph->nodes()) { + for (const Node &N : Graph->nodes()) { Nodes[Graph->getNodeIndex(N)] = Graph->getEdgeIndex(*N.edges_begin()); } Nodes[Graph->nodes_size()] = Graph->edges_size(); // terminator node - for (const auto &E : Graph->edges()) { + for (const Edge &E : Graph->edges()) { Edges[Graph->getEdgeIndex(E)] = Graph->getNodeIndex(*E.getDest()); EdgeValues[Graph->getEdgeIndex(E)] = E.getValue(); } @@ -651,74 +652,67 @@ LLVM_DEBUG(dbgs() << "Inserting LFENCEs... Done\n"); LLVM_DEBUG(dbgs() << "Inserted " << FencesInserted << " fences\n"); - Graph = GraphBuilder::trim(*Graph, MachineGadgetGraph::NodeSet{*Graph}, - CutEdges); + Graph = GraphBuilder::trim(*Graph, NodeSet{*Graph}, CutEdges); } while (true); return FencesInserted; } -int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithGreedyHeuristic( +int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithHeuristic( MachineFunction &MF, std::unique_ptr Graph) const { - LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n"); - Graph = trimMitigatedEdges(std::move(Graph)); - LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n"); + // If `MF` does not have any fences, then no gadgets would have been + // mitigated at this point. + if (Graph->NumFences > 0) { + LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n"); + Graph = trimMitigatedEdges(std::move(Graph)); + LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n"); + } + if (Graph->NumGadgets == 0) return 0; LLVM_DEBUG(dbgs() << "Cutting edges...\n"); - MachineGadgetGraph::NodeSet ElimNodes{*Graph}, GadgetSinks{*Graph}; - MachineGadgetGraph::EdgeSet ElimEdges{*Graph}, CutEdges{*Graph}; - auto IsCFGEdge = [&ElimEdges, &CutEdges](const MachineGadgetGraph::Edge &E) { - return !ElimEdges.contains(E) && !CutEdges.contains(E) && - MachineGadgetGraph::isCFGEdge(E); - }; - auto IsGadgetEdge = [&ElimEdges, - &CutEdges](const MachineGadgetGraph::Edge &E) { - return !ElimEdges.contains(E) && !CutEdges.contains(E) && - MachineGadgetGraph::isGadgetEdge(E); - }; - - // FIXME: this is O(E^2), we could probably do better. - do { - // Find the cheapest CFG edge that will eliminate a gadget (by being - // egress from a SOURCE node or ingress to a SINK node), and cut it. - const MachineGadgetGraph::Edge *CheapestSoFar = nullptr; - - // First, collect all gadget source and sink nodes. - MachineGadgetGraph::NodeSet GadgetSources{*Graph}, GadgetSinks{*Graph}; - for (const auto &N : Graph->nodes()) { - if (ElimNodes.contains(N)) + EdgeSet CutEdges{*Graph}; + + // Begin by collecting all ingress CFG edges for each node + DenseMap> IngressEdgeMap; + for (const Edge &E : Graph->edges()) + if (MachineGadgetGraph::isCFGEdge(E)) + IngressEdgeMap[E.getDest()].push_back(&E); + + // For each gadget edge, make cuts that guarantee the gadget will be + // mitigated. A computationally efficient way to achieve this is to either: + // (a) cut all egress CFG edges from the gadget source, or + // (b) cut all ingress CFG edges to the gadget sink. + // + // Moreover, the algorithm tries not to make a cut into a loop by preferring + // to make a (b)-type cut if the gadget source resides at a greater loop depth + // than the gadget sink, or an (a)-type cut otherwise. + for (const Node &N : Graph->nodes()) { + for (const Edge &E : N.edges()) { + if (!MachineGadgetGraph::isGadgetEdge(E)) continue; - for (const auto &E : N.edges()) { - if (IsGadgetEdge(E)) { - GadgetSources.insert(N); - GadgetSinks.insert(*E.getDest()); - } - } - } - // Next, look for the cheapest CFG edge which, when cut, is guaranteed to - // mitigate at least one gadget by either: - // (a) being egress from a gadget source, or - // (b) being ingress to a gadget sink. - for (const auto &N : Graph->nodes()) { - if (ElimNodes.contains(N)) - continue; - for (const auto &E : N.edges()) { - if (IsCFGEdge(E)) { - if (GadgetSources.contains(N) || GadgetSinks.contains(*E.getDest())) { - if (!CheapestSoFar || E.getValue() < CheapestSoFar->getValue()) - CheapestSoFar = &E; - } - } - } + SmallVector EgressEdges; + SmallVector &IngressEdges = IngressEdgeMap[E.getDest()]; + for (const Edge &EgressEdge : N.edges()) + if (MachineGadgetGraph::isCFGEdge(EgressEdge)) + EgressEdges.push_back(&EgressEdge); + + int EgressCutCost = 0, IngressCutCost = 0; + for (const Edge *EgressEdge : EgressEdges) + if (!CutEdges.contains(*EgressEdge)) + EgressCutCost += EgressEdge->getValue(); + for (const Edge *IngressEdge : IngressEdges) + if (!CutEdges.contains(*IngressEdge)) + IngressCutCost += IngressEdge->getValue(); + + auto &EdgesToCut = + IngressCutCost < EgressCutCost ? IngressEdges : EgressEdges; + for (const Edge *E : EdgesToCut) + CutEdges.insert(*E); } - - assert(CheapestSoFar && "Failed to cut an edge"); - CutEdges.insert(*CheapestSoFar); - ElimEdges.insert(*CheapestSoFar); - } while (elimMitigatedEdgesAndNodes(*Graph, ElimEdges, ElimNodes)); + } LLVM_DEBUG(dbgs() << "Cutting edges... Done\n"); LLVM_DEBUG(dbgs() << "Cut " << CutEdges.count() << " edges\n"); @@ -734,8 +728,8 @@ MachineFunction &MF, MachineGadgetGraph &G, EdgeSet &CutEdges /* in, out */) const { int FencesInserted = 0; - for (const auto &N : G.nodes()) { - for (const auto &E : N.edges()) { + for (const Node &N : G.nodes()) { + for (const Edge &E : N.edges()) { if (CutEdges.contains(E)) { MachineInstr *MI = N.getValue(), *Prev; MachineBasicBlock *MBB; // Insert an LFENCE in this MBB @@ -751,7 +745,7 @@ Prev = MI->getPrevNode(); // Remove all egress CFG edges from this branch because the inserted // LFENCE prevents gadgets from crossing the branch. - for (const auto &E : N.edges()) { + for (const Edge &E : N.edges()) { if (MachineGadgetGraph::isCFGEdge(E)) CutEdges.insert(E); }