diff --git a/llvm/lib/Target/BPF/BPF.h b/llvm/lib/Target/BPF/BPF.h
--- a/llvm/lib/Target/BPF/BPF.h
+++ b/llvm/lib/Target/BPF/BPF.h
@@ -15,6 +15,15 @@
 namespace llvm {
 class BPFTargetMachine;
 
+class AdjustOptPhases {
+public:
+  enum {
+    BPF_EARLY_GEN_IR_OPT = 0,
+    BPF_EARLY_TGT_IR_OPT = 1,
+  };
+};
+
+ModulePass *createBPFAdjustOpt(int Phase);
 ModulePass *createBPFAbstractMemberAccess(BPFTargetMachine *TM);
 ModulePass *createBPFPreserveDIType();
 
@@ -25,6 +34,7 @@
 FunctionPass *createBPFMIPreEmitPeepholePass();
 FunctionPass *createBPFMIPreEmitCheckingPass();
 
+void initializeBPFAdjustOptPass(PassRegistry&);
 void initializeBPFAbstractMemberAccessPass(PassRegistry&);
 void initializeBPFPreserveDITypePass(PassRegistry&);
 void initializeBPFMISimplifyPatchablePass(PassRegistry&);
diff --git a/llvm/lib/Target/BPF/BPFAdjustOpt.cpp b/llvm/lib/Target/BPF/BPFAdjustOpt.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/BPF/BPFAdjustOpt.cpp
@@ -0,0 +1,333 @@
+//===---------------- BPFAdjustOpt.cpp - Adjust Optimization --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Adjust optimization to make the code more kernel verifier friendly.
+//
+//===----------------------------------------------------------------------===//
+
+#include "BPF.h"
+#include "BPFTargetMachine.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+#define DEBUG_TYPE "bpf-adjust-opt"
+
+using namespace llvm;
+
+namespace {
+
+class BPFAdjustOpt final : public ModulePass {
+  StringRef getPassName() const override { return "BPF Adjust Optimization"; }
+
+  bool runOnModule(Module &M) override;
+
+public:
+  static char ID;
+  explicit BPFAdjustOpt(int Phase = AdjustOptPhases::BPF_EARLY_GEN_IR_OPT)
+      : ModulePass(ID), Phase(Phase) {}
+
+private:
+  // A pending rewrite: operand OpIdx of UsedInst, which currently holds Val,
+  // is replaced by user_barrier(Val) inserted right before UsedInst.
+  struct BarrierInfo {
+    Value *Val;
+    Instruction *UsedInst;
+    uint32_t OpIdx;
+    BarrierInfo(Value *V, Instruction *U, uint32_t Idx)
+        : Val(V), UsedInst(U), OpIdx(Idx) {}
+  };
+
+  int Phase;
+  SmallVector<BarrierInfo, 16> Barriers;
+
+  bool adjustOpt(Module &M);
+  void adjustBasicBlock(BasicBlock &BB);
+  void adjustInst(Instruction &I);
+  bool serializeICMPInBB(Instruction &I);
+  bool serializeICMPCrossBB(BasicBlock &BB);
+  bool avoidSpeculation(Instruction &I);
+  bool serializeSelectCond(Instruction &I);
+  bool insertUserBarrier(const BarrierInfo &Info);
+};
+} // End anonymous namespace
+
+char BPFAdjustOpt::ID = 0;
+INITIALIZE_PASS(BPFAdjustOpt, DEBUG_TYPE, "adjust optimization", false, false)
+
+ModulePass *llvm::createBPFAdjustOpt(int Phase) {
+  return new BPFAdjustOpt(Phase);
+}
+
+bool BPFAdjustOpt::runOnModule(Module &M) {
+  LLVM_DEBUG(dbgs() << "******** BPF Adjust Optimization ********\n");
+
+  return adjustOpt(M);
+}
+
+// Replace operand Info.OpIdx of Info.UsedInst with
+//   call asm sideeffect "", "=r,0"(Info.Val)
+// inserted right before Info.UsedInst. The optimizer cannot look through
+// the side-effecting inline asm, so it can no longer relate the new operand
+// to Info.Val. Returns false, without changing anything, if the operand no
+// longer holds Info.Val (e.g. a duplicate request whose operand has already
+// been replaced by an earlier barrier).
+bool BPFAdjustOpt::insertUserBarrier(const BarrierInfo &Info) {
+  Value *V = Info.Val;
+  if (Info.UsedInst->getOperand(Info.OpIdx) != V)
+    return false;
+
+  FunctionType *AsmFTy =
+      FunctionType::get(V->getType(), {V->getType()}, false);
+  InlineAsm *Asm = InlineAsm::get(AsmFTy, StringRef(""), StringRef("=r,0"),
+                                  /*hasSideEffects=*/true);
+  auto *CI = CallInst::Create(Asm, {V}, "", Info.UsedInst);
+  Info.UsedInst->setOperand(Info.OpIdx, CI);
+  return true;
+}
+
+// To avoid combining conditionals in the same basic block by
+// instcombine optimization.
+bool BPFAdjustOpt::serializeICMPInBB(Instruction &I) {
+  // For:
+  //   comp1 = icmp <opcode> ...;
+  //   comp2 = icmp <opcode> ...;
+  //   ... or comp1 comp2 ...
+  // changed to:
+  //   comp1 = icmp <opcode> ...;
+  //   comp2 = icmp <opcode> ...;
+  //   new_comp1 = user_barrier(comp1)
+  //   ... or new_comp1 comp2 ...
+  // user_barrier is an inline asm:
+  //   call asm sideeffect "", "=r,0"(comp1)
+  if (I.getOpcode() != Instruction::Or)
+    return false;
+  auto *Icmp1 = dyn_cast<ICmpInst>(I.getOperand(0));
+  if (!Icmp1)
+    return false;
+  auto *Icmp2 = dyn_cast<ICmpInst>(I.getOperand(1));
+  if (!Icmp2)
+    return false;
+
+  Value *Icmp1Op0 = Icmp1->getOperand(0);
+  Value *Icmp2Op0 = Icmp2->getOperand(0);
+  if (Icmp1Op0 != Icmp2Op0)
+    return false;
+
+  // Now we got two icmp instructions which feed into
+  // an "or" instruction.
+  Barriers.emplace_back(Icmp1, &I, 0);
+  return true;
+}
+
+// To avoid the conditionals of two consecutive basic blocks being combined
+// into a single basic block by the optimizer.
+bool BPFAdjustOpt::serializeICMPCrossBB(BasicBlock &BB) {
+  // For:
+  //   B1:
+  //     comp1 = icmp <opcode> ...;
+  //     if (comp1) goto B2 else B3;
+  //   B2:
+  //     comp2 = icmp <opcode> ...;
+  //     if (comp2) goto B4 else B5;
+  //   B4:
+  //     ...
+  // changed to:
+  //   B1:
+  //     comp1 = icmp <opcode> ...;
+  //     comp1 = user_barrier(comp1);
+  //     if (comp1) goto B2 else B3;
+  //   B2:
+  //     comp2 = icmp <opcode> ...;
+  //     if (comp2) goto B4 else B5;
+  //   B4:
+  //     ...
+
+  // BB plays the role of B4. Walk up two single-predecessor links to find
+  // B2 and B1; if both end in a conditional branch on an icmp and the two
+  // icmps share their first operand, add a barrier on B1's branch condition.
+  BasicBlock *B2 = BB.getSinglePredecessor();
+  if (!B2)
+    return false;
+
+  BasicBlock *B1 = B2->getSinglePredecessor();
+  if (!B1)
+    return false;
+
+  auto *BI = dyn_cast<BranchInst>(B2->getTerminator());
+  if (!BI || !BI->isConditional())
+    return false;
+  auto *Cond = dyn_cast<ICmpInst>(BI->getCondition());
+  if (!Cond)
+    return false;
+  Value *B2Op0 = Cond->getOperand(0);
+
+  BI = dyn_cast<BranchInst>(B1->getTerminator());
+  if (!BI || !BI->isConditional())
+    return false;
+  Cond = dyn_cast<ICmpInst>(BI->getCondition());
+  if (!Cond)
+    return false;
+  Value *B1Op0 = Cond->getOperand(0);
+
+  if (B1Op0 != B2Op0)
+    return false;
+
+  // Insert the user barrier in B1, right before its branch; the condition
+  // is operand 0 of a conditional branch. When several successors of B2
+  // reach this point, insertUserBarrier() skips the duplicate requests.
+  Barriers.emplace_back(Cond, BI, 0);
+  return true;
+}
+
+// To avoid speculative hoisting certain computations out of
+// a basic block.
+bool BPFAdjustOpt::avoidSpeculation(Instruction &I) {
+  // For:
+  //   B1:
+  //     var = ...
+  //     comp1 = icmp <opcode> var, <const>;
+  //     if (comp1) goto B2 else B3;
+  //   B2:
+  //     ... var ...
+  // change to:
+  //   B1:
+  //     var = ...
+  //     comp1 = icmp <opcode> var, <const>;
+  //     if (comp1) goto B2 else B3;
+  //   B2:
+  //     var = user_barrier(var);
+  //     ... var ...
+  bool IsCandidate = false;
+  SmallVector<BarrierInfo, 4> Candidates;
+  BasicBlock *B1 = I.getParent();
+  for (User *U : I.users()) {
+    auto *Inst = dyn_cast<Instruction>(U);
+    if (!Inst)
+      continue;
+
+    if (Inst->getParent() == B1) {
+      // Every use in the defining block must be an icmp against a constant.
+      auto *Icmp1 = dyn_cast<ICmpInst>(Inst);
+      if (!Icmp1 || !isa<Constant>(Icmp1->getOperand(1)))
+        return false;
+      IsCandidate = true;
+      continue;
+    }
+
+    // Use in a different basic block: only GEPs are handled for now.
+    auto *GI = dyn_cast<GetElementPtrInst>(Inst);
+    if (!GI)
+      continue;
+
+    // Find the first index operand (operand 0 is the pointer) that uses I.
+    unsigned Idx = 1, E = GI->getNumOperands();
+    while (Idx != E && GI->getOperand(Idx) != &I)
+      ++Idx;
+    if (Idx == E)
+      continue;
+
+    Candidates.emplace_back(&I, GI, Idx);
+  }
+
+  if (!IsCandidate || Candidates.empty())
+    return false;
+
+  Barriers.append(Candidates.begin(), Candidates.end());
+  return true;
+}
+
+// To give every select/br user of an icmp its own barrier when the icmp
+// result is used as a condition in multiple places.
+bool BPFAdjustOpt::serializeSelectCond(Instruction &I) {
+  // For:
+  //   icmp1 = ...
+  //   ...
+  //   a = select icmp1, val1, val2
+  //   ...
+  //   br icmp1, ...
+  // change to:
+  //   icmp1 = ...
+  //   ...
+  //   new_icmp1 = user_barrier(icmp1)
+  //   a = select new_icmp1, val1, val2
+  //   ...
+  //   new_icmp2 = user_barrier(icmp1)
+  //   br new_icmp2, ...
+  if (!isa<ICmpInst>(&I))
+    return false;
+
+  SmallVector<BarrierInfo, 4> Candidates;
+  for (Use &U : I.uses()) {
+    auto *Inst = dyn_cast<Instruction>(U.getUser());
+    if (!Inst)
+      continue;
+
+    // Every use must be the condition (operand 0) of a 'select' or a 'br'.
+    // A select that uses I as one of its values does not match, and
+    // rewriting its operand 0 would change the selected condition.
+    if (!isa<SelectInst>(Inst) && !isa<BranchInst>(Inst))
+      return false;
+    if (U.getOperandNo() != 0)
+      return false;
+
+    Candidates.emplace_back(&I, Inst, 0);
+  }
+
+  if (Candidates.size() < 2)
+    return false;
+
+  Barriers.append(Candidates.begin(), Candidates.end());
+  return true;
+}
+
+void BPFAdjustOpt::adjustBasicBlock(BasicBlock &BB) {
+  if (Phase == AdjustOptPhases::BPF_EARLY_GEN_IR_OPT) {
+    if (serializeICMPCrossBB(BB))
+      return;
+  }
+}
+
+void BPFAdjustOpt::adjustInst(Instruction &I) {
+  if (Phase == AdjustOptPhases::BPF_EARLY_GEN_IR_OPT) {
+    if (serializeICMPInBB(I))
+      return;
+    if (avoidSpeculation(I))
+      return;
+  } else {
+    if (serializeSelectCond(I))
+      return;
+  }
+}
+
+// Collect all rewrites for the module first and apply them afterwards, so
+// the IR is not modified while it is being traversed.
+bool BPFAdjustOpt::adjustOpt(Module &M) {
+  for (Function &F : M)
+    for (BasicBlock &BB : F) {
+      adjustBasicBlock(BB);
+      for (Instruction &I : BB)
+        adjustInst(I);
+    }
+
+  bool Changed = false;
+  for (const BarrierInfo &Info : Barriers)
+    Changed |= insertUserBarrier(Info);
+
+  // The recorded instructions belong to M; clear them so a later run of this
+  // pass object never touches stale (possibly freed) instructions.
+  Barriers.clear();
+  return Changed;
+}
diff --git a/llvm/lib/Target/BPF/BPFTargetMachine.cpp b/llvm/lib/Target/BPF/BPFTargetMachine.cpp
--- a/llvm/lib/Target/BPF/BPFTargetMachine.cpp
+++ b/llvm/lib/Target/BPF/BPFTargetMachine.cpp
@@ -37,6 +37,7 @@
   RegisterTargetMachine<BPFTargetMachine> Z(getTheBPFTarget());
 
   PassRegistry &PR = *PassRegistry::getPassRegistry();
+  initializeBPFAdjustOptPass(PR);
   initializeBPFAbstractMemberAccessPass(PR);
   initializeBPFPreserveDITypePass(PR);
   initializeBPFMIPeepholePass(PR);
@@ -104,11 +105,17 @@
         PM.add(createCFGSimplificationPass(
             SimplifyCFGOptions().hoistCommonInsts(true)));
       });
+  Builder.addExtension(
+      PassManagerBuilder::EP_ModuleOptimizerEarly,
+      [&](const PassManagerBuilder &, legacy::PassManagerBase &PM) {
+        PM.add(createBPFAdjustOpt(AdjustOptPhases::BPF_EARLY_GEN_IR_OPT));
+      });
 }
 
 void BPFPassConfig::addIRPasses() {
   addPass(createBPFAbstractMemberAccess(&getBPFTargetMachine()));
   addPass(createBPFPreserveDIType());
+  addPass(createBPFAdjustOpt(AdjustOptPhases::BPF_EARLY_TGT_IR_OPT));
 
   TargetPassConfig::addIRPasses();
 }
diff --git a/llvm/lib/Target/BPF/CMakeLists.txt b/llvm/lib/Target/BPF/CMakeLists.txt
--- a/llvm/lib/Target/BPF/CMakeLists.txt
+++ b/llvm/lib/Target/BPF/CMakeLists.txt
@@ -14,6 +14,7 @@
 
 add_llvm_target(BPFCodeGen
   BPFAbstractMemberAccess.cpp
+  BPFAdjustOpt.cpp
   BPFAsmPrinter.cpp
   BPFFrameLowering.cpp
   BPFInstrInfo.cpp