Index: llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp =================================================================== --- llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp +++ llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp @@ -42,6 +42,7 @@ #include "X86TargetMachine.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" @@ -104,9 +105,9 @@ cl::init(false), cl::Hidden); static llvm::sys::DynamicLibrary OptimizeDL; -typedef int (*OptimizeCutT)(unsigned int *nodes, unsigned int nodes_size, - unsigned int *edges, int *edge_values, - int *cut_edges /* out */, unsigned int edges_size); +typedef int (*OptimizeCutT)(unsigned int *Nodes, unsigned int NodesSize, + unsigned int *Edges, int *EdgeValues, + int *CutEdges /* out */, unsigned int EdgesSize); static OptimizeCutT OptimizeCut = nullptr; namespace { @@ -162,8 +163,8 @@ const MachineDominanceFrontier &MDF) const; int hardenLoadsWithPlugin(MachineFunction &MF, std::unique_ptr Graph) const; - int hardenLoadsWithGreedyHeuristic( - MachineFunction &MF, std::unique_ptr Graph) const; + int hardenLoadsWithHeuristic(MachineFunction &MF, + std::unique_ptr Graph) const; int elimMitigatedEdgesAndNodes(MachineGadgetGraph &G, EdgeSet &ElimEdges /* in, out */, NodeSet &ElimNodes /* in, out */) const; @@ -198,7 +199,7 @@ using ChildIteratorType = typename Traits::ChildIteratorType; using ChildEdgeIteratorType = typename Traits::ChildEdgeIteratorType; - DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {} + DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {} std::string getNodeLabel(NodeRef Node, GraphType *) { if (Node->getValue() == MachineGadgetGraph::ArgNodeSentinel) @@ -243,7 +244,7 @@ AU.setPreservesCFG(); } -static void WriteGadgetGraph(raw_ostream &OS, MachineFunction &MF, +static void writeGadgetGraph(raw_ostream &OS, MachineFunction &MF, MachineGadgetGraph *G) { WriteGraph(OS, G, /*ShortNames*/ false, "Speculative gadgets for \"" + MF.getName() + "\" function"); @@ -279,7 +280,7 @@ return false; // didn't find any gadgets if (EmitDotVerify) { - WriteGadgetGraph(outs(), MF, Graph.get()); + writeGadgetGraph(outs(), MF, Graph.get()); return false; } @@ -292,7 +293,7 @@ raw_fd_ostream FileOut(FileName, FileError); if (FileError) errs() << FileError.message(); - WriteGadgetGraph(FileOut, MF, Graph.get()); + writeGadgetGraph(FileOut, MF, Graph.get()); FileOut.close(); LLVM_DEBUG(dbgs() << "Emitting gadget graph... Done\n"); if (EmitDotOnly) @@ -313,7 +314,7 @@ } FencesInserted = hardenLoadsWithPlugin(MF, std::move(Graph)); } else { // Use the default greedy heuristic - FencesInserted = hardenLoadsWithGreedyHeuristic(MF, std::move(Graph)); + FencesInserted = hardenLoadsWithHeuristic(MF, std::move(Graph)); } if (FencesInserted > 0) @@ -658,67 +659,63 @@ return FencesInserted; } -int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithGreedyHeuristic( +int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithHeuristic( MachineFunction &MF, std::unique_ptr Graph) const { - LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n"); - Graph = trimMitigatedEdges(std::move(Graph)); - LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n"); + using NodeRef = const MachineGadgetGraph::Node *; + using EdgeRef = const MachineGadgetGraph::Edge *; + + // If `MF` does not have any fences, then no gadgets would have been + // mitigated at this point. + if (Graph->NumFences > 0) { + LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n"); + Graph = trimMitigatedEdges(std::move(Graph)); + LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n"); + } + if (Graph->NumGadgets == 0) return 0; LLVM_DEBUG(dbgs() << "Cutting edges...\n"); - MachineGadgetGraph::NodeSet ElimNodes{*Graph}, GadgetSinks{*Graph}; - MachineGadgetGraph::EdgeSet ElimEdges{*Graph}, CutEdges{*Graph}; - auto IsCFGEdge = [&ElimEdges, &CutEdges](const MachineGadgetGraph::Edge &E) { - return !ElimEdges.contains(E) && !CutEdges.contains(E) && - MachineGadgetGraph::isCFGEdge(E); - }; - auto IsGadgetEdge = [&ElimEdges, - &CutEdges](const MachineGadgetGraph::Edge &E) { - return !ElimEdges.contains(E) && !CutEdges.contains(E) && - MachineGadgetGraph::isGadgetEdge(E); - }; - - // FIXME: this is O(E^2), we could probably do better. - do { - // Find the cheapest CFG edge that will eliminate a gadget (by being - // egress from a SOURCE node or ingress to a SINK node), and cut it. - const MachineGadgetGraph::Edge *CheapestSoFar = nullptr; - - // First, collect all gadget source and sink nodes. - MachineGadgetGraph::NodeSet GadgetSources{*Graph}, GadgetSinks{*Graph}; - for (const auto &N : Graph->nodes()) { - if (ElimNodes.contains(N)) + MachineGadgetGraph::EdgeSet CutEdges{*Graph}; + + // Begin by collecting all ingress CFG edges for each node + DenseMap> IngressEdgeMap; + for (const auto &E : Graph->edges()) + if (MachineGadgetGraph::isCFGEdge(E)) + IngressEdgeMap[E.getDest()].push_back(&E); + + // For each gadget edge, make cuts that guarantee the gadget will be + // mitigated. A computationally efficient way to achieve this is to either: + // (a) cut all egress CFG edges from the gadget source, or + // (b) cut all ingress CFG edges to the gadget sink. + // + // Moreover, the algorithm tries not to make a cut into a loop by preferring + // to make a (b)-type cut if the gadget source resides at a greater loop depth + // than the gadget sink, or an (a)-type cut otherwise. + for (const auto &N : Graph->nodes()) { + for (const auto &E : N.edges()) { + if (!MachineGadgetGraph::isGadgetEdge(E)) continue; - for (const auto &E : N.edges()) { - if (IsGadgetEdge(E)) { - GadgetSources.insert(N); - GadgetSinks.insert(*E.getDest()); - } - } - } - // Next, look for the cheapest CFG edge which, when cut, is guaranteed to - // mitigate at least one gadget by either: - // (a) being egress from a gadget source, or - // (b) being ingress to a gadget sink. - for (const auto &N : Graph->nodes()) { - if (ElimNodes.contains(N)) - continue; - for (const auto &E : N.edges()) { - if (IsCFGEdge(E)) { - if (GadgetSources.contains(N) || GadgetSinks.contains(*E.getDest())) { - if (!CheapestSoFar || E.getValue() < CheapestSoFar->getValue()) - CheapestSoFar = &E; - } - } - } + SmallVector EgressEdges; + SmallVector &IngressEdges = IngressEdgeMap[E.getDest()]; + for (const auto &EgressEdge : N.edges()) + if (MachineGadgetGraph::isCFGEdge(EgressEdge)) + EgressEdges.push_back(&EgressEdge); + + int EgressCutCost = 0, IngressCutCost = 0; + for (const auto *EgressEdge : EgressEdges) + if (!CutEdges.contains(*EgressEdge)) + EgressCutCost += EgressEdge->getValue(); + for (const auto *IngressEdge : IngressEdges) + if (!CutEdges.contains(*IngressEdge)) + IngressCutCost += IngressEdge->getValue(); + + auto &EdgesToCut = + IngressCutCost < EgressCutCost ? IngressEdges : EgressEdges; + llvm::for_each(EdgesToCut, [&](EdgeRef E) { CutEdges.insert(*E); }); } - - assert(CheapestSoFar && "Failed to cut an edge"); - CutEdges.insert(*CheapestSoFar); - ElimEdges.insert(*CheapestSoFar); - } while (elimMitigatedEdgesAndNodes(*Graph, ElimEdges, ElimNodes)); + } LLVM_DEBUG(dbgs() << "Cutting edges... Done\n"); LLVM_DEBUG(dbgs() << "Cut " << CutEdges.count() << " edges\n");