diff --git a/llvm/lib/Target/BPF/BPF.h b/llvm/lib/Target/BPF/BPF.h
--- a/llvm/lib/Target/BPF/BPF.h
+++ b/llvm/lib/Target/BPF/BPF.h
@@ -15,6 +15,7 @@
 namespace llvm {
 class BPFTargetMachine;
 
+ModulePass *createBPFAdjustOpt();
 ModulePass *createBPFCheckAndAdjustIR();
 
 FunctionPass *createBPFAbstractMemberAccess(BPFTargetMachine *TM);
@@ -27,7 +28,7 @@
 FunctionPass *createBPFMIPreEmitCheckingPass();
 
 void initializeBPFCheckAndAdjustIRPass(PassRegistry&);
-
+void initializeBPFAdjustOptPass(PassRegistry&);
 void initializeBPFAbstractMemberAccessPass(PassRegistry&);
 void initializeBPFPreserveDITypePass(PassRegistry&);
 void initializeBPFMISimplifyPatchablePass(PassRegistry&);
diff --git a/llvm/lib/Target/BPF/BPFAdjustOpt.cpp b/llvm/lib/Target/BPF/BPFAdjustOpt.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/BPF/BPFAdjustOpt.cpp
@@ -0,0 +1,311 @@
+//===---------------- BPFAdjustOpt.cpp - Adjust Optimization --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Adjust optimization to make the code more kernel verifier friendly.
+//
+//===----------------------------------------------------------------------===//
+
+#include "BPF.h"
+#include "BPFCORE.h"
+#include "BPFTargetMachine.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+#define DEBUG_TYPE "bpf-adjust-opt"
+
+using namespace llvm;
+
+namespace {
+
+// Module pass that records (value, user, operand index) triples for which a
+// __builtin_bpf_passthrough barrier is needed, and afterwards rewrites each
+// recorded user to consume the barrier result instead of the raw value.
+class BPFAdjustOpt final : public ModulePass {
+  // One pending rewrite: operand OpIdx of UsedInst, currently Input, is to be
+  // replaced by a passthrough of Input.
+  struct PassThroughInfo {
+    Instruction *Input;
+    Instruction *UsedInst;
+    uint32_t OpIdx;
+    PassThroughInfo(Instruction *I, Instruction *U, uint32_t Idx)
+        : Input(I), UsedInst(U), OpIdx(Idx) {}
+  };
+
+public:
+  static char ID;
+  Module *Mod;
+
+  BPFAdjustOpt() : ModulePass(ID) {}
+  bool runOnModule(Module &M) override;
+
+private:
+  // Rewrites collected while scanning the module; they are applied in one
+  // batch by insertPassThrough() once the scan is complete.
+  SmallVector<PassThroughInfo, 16> PassThroughs;
+
+  void adjustBasicBlock(BasicBlock &BB);
+  bool serializeICMPCrossBB(BasicBlock &BB);
+  void adjustInst(Instruction &I);
+  bool serializeICMPInBB(Instruction &I);
+  bool avoidSpeculation(Instruction &I);
+  bool insertPassThrough();
+};
+
+} // End anonymous namespace
+
+char BPFAdjustOpt::ID = 0;
+INITIALIZE_PASS(BPFAdjustOpt, "bpf-adjust-opt", "BPF Adjust Optimization",
+                false, false)
+
+ModulePass *llvm::createBPFAdjustOpt() { return new BPFAdjustOpt(); }
+
+// Scan every basic block and instruction of M for the patterns handled below
+// and then apply the collected rewrites. Returns true if at least one rewrite
+// was recorded.
+bool BPFAdjustOpt::runOnModule(Module &M) {
+  Mod = &M;
+  for (Function &F : M)
+    for (auto &BB : F) {
+      adjustBasicBlock(BB);
+      for (auto &I : BB)
+        adjustInst(I);
+    }
+
+  return insertPassThrough();
+}
+
+// Apply every recorded rewrite: obtain a passthrough value for Info.Input from
+// BPFCoreSharedInfo::insertPassThrough() and install it as operand Info.OpIdx
+// of Info.UsedInst. Returns true if at least one rewrite was recorded.
+bool BPFAdjustOpt::insertPassThrough() {
+  for (auto &Info : PassThroughs) {
+    auto *CI = BPFCoreSharedInfo::insertPassThrough(
+        Mod, Info.UsedInst->getParent(), Info.Input, Info.UsedInst);
+    Info.UsedInst->setOperand(Info.OpIdx, CI);
+  }
+
+  return !PassThroughs.empty();
+}
+
+// To avoid combining conditionals in the same basic block by
+// instcombine optimization. Returns true if a barrier was recorded for I.
+bool BPFAdjustOpt::serializeICMPInBB(Instruction &I) {
+  // For:
+  //   comp1 = icmp <opcode> ...;
+  //   comp2 = icmp <opcode> ...;
+  //   ... or comp1 comp2 ...
+  // changed to:
+  //   comp1 = icmp <opcode> ...;
+  //   comp2 = icmp <opcode> ...;
+  //   new_comp1 = __builtin_bpf_passthrough(seq_num, comp1)
+  //   ... or new_comp1 comp2 ...
+  if (I.getOpcode() != Instruction::Or)
+    return false;
+  auto *Icmp1 = dyn_cast<ICmpInst>(I.getOperand(0));
+  if (!Icmp1)
+    return false;
+  auto *Icmp2 = dyn_cast<ICmpInst>(I.getOperand(1));
+  if (!Icmp2)
+    return false;
+
+  Value *Icmp1Op0 = Icmp1->getOperand(0);
+  Value *Icmp2Op0 = Icmp2->getOperand(0);
+  if (Icmp1Op0 != Icmp2Op0)
+    return false;
+
+  // Now we got two icmp instructions on the same first operand which
+  // feed into an "or" instruction.
+  PassThroughInfo Info(Icmp1, &I, 0);
+  PassThroughs.push_back(Info);
+  return true;
+}
+
+// To avoid combining conditionals from adjacent basic blocks by
+// instcombine optimization. Returns true if a barrier was recorded.
+bool BPFAdjustOpt::serializeICMPCrossBB(BasicBlock &BB) {
+  // For:
+  //   B1:
+  //     comp1 = icmp <opcode> ...;
+  //     if (comp1) goto B2 else B3;
+  //   B2:
+  //     comp2 = icmp <opcode> ...;
+  //     if (comp2) goto B4 else B5;
+  //   B4:
+  //     ...
+  // changed to:
+  //   B1:
+  //     comp1 = icmp <opcode> ...;
+  //     comp1 = __builtin_bpf_passthrough(seq_num, comp1);
+  //     if (comp1) goto B2 else B3;
+  //   B2:
+  //     comp2 = icmp <opcode> ...;
+  //     if (comp2) goto B4 else B5;
+  //   B4:
+  //     ...
+
+  // Check basic predecessors, if two of them (say B1, B2) are using
+  // icmp instructions to generate conditions and one is the predecessor
+  // of another (e.g., B1 is the predecessor of B2). Add a passthrough
+  // barrier after icmp inst of block B1.
+  BasicBlock *B2 = BB.getSinglePredecessor();
+  if (!B2)
+    return false;
+
+  BasicBlock *B1 = B2->getSinglePredecessor();
+  if (!B1)
+    return false;
+
+  Instruction *TI = B2->getTerminator();
+  auto *BI = dyn_cast<BranchInst>(TI);
+  if (!BI || !BI->isConditional())
+    return false;
+  auto *Cond = dyn_cast<ICmpInst>(BI->getCondition());
+  if (!Cond || B2->getFirstNonPHI() != Cond)
+    return false;
+  Value *B2Op0 = Cond->getOperand(0);
+  auto Cond2Op = Cond->getPredicate();
+
+  TI = B1->getTerminator();
+  BI = dyn_cast<BranchInst>(TI);
+  if (!BI || !BI->isConditional())
+    return false;
+  Cond = dyn_cast<ICmpInst>(BI->getCondition());
+  if (!Cond)
+    return false;
+  Value *B1Op0 = Cond->getOperand(0);
+  auto Cond1Op = Cond->getPredicate();
+
+  if (B1Op0 != B2Op0)
+    return false;
+
+  // Only handle a signed ">"/">=" in B1 paired with a signed "<"/"<=" in B2,
+  // or the reverse pairing. Both tests below must inspect Cond2Op, the
+  // predicate of B2's comparison.
+  if (Cond1Op == ICmpInst::ICMP_SGT || Cond1Op == ICmpInst::ICMP_SGE) {
+    if (Cond2Op != ICmpInst::ICMP_SLT && Cond2Op != ICmpInst::ICMP_SLE)
+      return false;
+  } else if (Cond1Op == ICmpInst::ICMP_SLT || Cond1Op == ICmpInst::ICMP_SLE) {
+    if (Cond2Op != ICmpInst::ICMP_SGT && Cond2Op != ICmpInst::ICMP_SGE)
+      return false;
+  } else {
+    return false;
+  }
+
+  // Cond and BI now refer to B1: guard the condition operand of B1's branch.
+  PassThroughInfo Info(Cond, BI, 0);
+  PassThroughs.push_back(Info);
+
+  return true;
+}
+
+// To avoid speculative hoisting certain computations out of
+// a basic block.
+bool BPFAdjustOpt::avoidSpeculation(Instruction &I) {
+  // Loads of globals carrying the AmaAttr or TypeIdAttr attribute are left
+  // untouched.
+  if (auto *LdInst = dyn_cast<LoadInst>(&I)) {
+    if (auto *GV = dyn_cast<GlobalVariable>(LdInst->getOperand(0))) {
+      if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
+          GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
+        return false;
+    }
+  }
+
+  // Only values produced by a load or a call are considered.
+  if (!isa<LoadInst>(&I) && !isa<CallInst>(&I))
+    return false;
+
+  // For:
+  //   B1:
+  //     var = ...
+  //     comp1 = icmp <opcode> var, <const>;
+  //     if (comp1) goto B2 else B3;
+  //   B2:
+  //     ... var ...
+  // change to:
+  //   B1:
+  //     var = ...
+  //     comp1 = icmp <opcode> var, <const>;
+  //     if (comp1) goto B2 else B3;
+  //   B2:
+  //     var = __builtin_bpf_passthrough(seq_num, var);
+  //     ... var ...
+  bool isCandidate = false;
+  SmallVector<PassThroughInfo, 4> Candidates;
+  for (User *U : I.users()) {
+    Instruction *Inst = dyn_cast<Instruction>(U);
+    if (!Inst)
+      continue;
+
+    // FIXME: not really precise
+    if (auto *Icmp1 = dyn_cast<ICmpInst>(Inst)) {
+      Value *Icmp1Op1 = Icmp1->getOperand(1);
+      if (!isa<ConstantInt>(Icmp1Op1))
+        return false;
+      isCandidate = true;
+      continue;
+    }
+
+    if (Inst->getParent() == I.getParent())
+      continue;
+
+    // Give up if a call, load or store appears in the user's block at or
+    // before the user.
+    for (auto &I2 : *Inst->getParent()) {
+      if (isa<CallInst>(&I2))
+        return false;
+      if (isa<LoadInst>(&I2) || isa<StoreInst>(&I2))
+        return false;
+      if (&I2 == Inst)
+        break;
+    }
+
+    // use in a different basic block
+    // only handle GEP now.
+    if (auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
+      // traverse GEP inst to find Use operand index
+      unsigned i, e;
+      for (i = 1, e = GI->getNumOperands(); i != e; ++i) {
+        Value *V = GI->getOperand(i);
+        if (V == &I)
+          break;
+      }
+      if (i == e)
+        continue;
+
+      PassThroughInfo Info(&I, GI, i);
+      Candidates.push_back(Info);
+    }
+  }
+
+  if (!isCandidate || Candidates.empty())
+    return false;
+
+  PassThroughs.insert(PassThroughs.end(), Candidates.begin(), Candidates.end());
+  return true;
+}
+
+// Run the basic-block level adjustments on BB.
+void BPFAdjustOpt::adjustBasicBlock(BasicBlock &BB) {
+  if (serializeICMPCrossBB(BB))
+    return;
+}
+
+// Run the instruction level adjustments on I; the first that applies wins.
+void BPFAdjustOpt::adjustInst(Instruction &I) {
+  if (serializeICMPInBB(I))
+    return;
+  if (avoidSpeculation(I))
+    return;
+}
diff --git a/llvm/lib/Target/BPF/BPFTargetMachine.cpp b/llvm/lib/Target/BPF/BPFTargetMachine.cpp
--- a/llvm/lib/Target/BPF/BPFTargetMachine.cpp
+++ b/llvm/lib/Target/BPF/BPFTargetMachine.cpp
@@ -37,6 +37,7 @@
   RegisterTargetMachine<BPFTargetMachine> Z(getTheBPFTarget());
 
   PassRegistry &PR = *PassRegistry::getPassRegistry();
+  initializeBPFAdjustOptPass(PR);
   initializeBPFAbstractMemberAccessPass(PR);
   initializeBPFPreserveDITypePass(PR);
   initializeBPFCheckAndAdjustIRPass(PR);
@@ -112,10 +113,16 @@
                     PM.add(createCFGSimplificationPass(
                         SimplifyCFGOptions().hoistCommonInsts(true)));
                   });
+  Builder.addExtension(
+      PassManagerBuilder::EP_ModuleOptimizerEarly,
+      [&](const PassManagerBuilder &, legacy::PassManagerBase &PM) {
+        PM.add(createBPFAdjustOpt());
+      });
 }
 
 void BPFPassConfig::addIRPasses() {
   addPass(createBPFCheckAndAdjustIR());
+
   TargetPassConfig::addIRPasses();
 }
 
diff --git a/llvm/lib/Target/BPF/CMakeLists.txt b/llvm/lib/Target/BPF/CMakeLists.txt
--- a/llvm/lib/Target/BPF/CMakeLists.txt
+++ b/llvm/lib/Target/BPF/CMakeLists.txt
@@ -14,6 +14,7 @@
 
 add_llvm_target(BPFCodeGen
   BPFAbstractMemberAccess.cpp
+  BPFAdjustOpt.cpp
   BPFAsmPrinter.cpp
   BPFCheckAndAdjustIR.cpp
   BPFFrameLowering.cpp