diff --git a/llvm/lib/Target/BPF/BPF.h b/llvm/lib/Target/BPF/BPF.h
--- a/llvm/lib/Target/BPF/BPF.h
+++ b/llvm/lib/Target/BPF/BPF.h
@@ -15,6 +15,7 @@
 namespace llvm {
 class BPFTargetMachine;
 
+ModulePass *createBPFAdjustOpt();
 ModulePass *createBPFAbstractMemberAccess(BPFTargetMachine *TM);
 ModulePass *createBPFPreserveDIType();
 
@@ -25,6 +26,7 @@
 FunctionPass *createBPFMIPreEmitPeepholePass();
 FunctionPass *createBPFMIPreEmitCheckingPass();
 
+void initializeBPFAdjustOptPass(PassRegistry&);
 void initializeBPFAbstractMemberAccessPass(PassRegistry&);
 void initializeBPFPreserveDITypePass(PassRegistry&);
 void initializeBPFMISimplifyPatchablePass(PassRegistry&);
diff --git a/llvm/lib/Target/BPF/BPFAdjustOpt.cpp b/llvm/lib/Target/BPF/BPFAdjustOpt.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/BPF/BPFAdjustOpt.cpp
@@ -0,0 +1,196 @@
+//===---------------- BPFAdjustOpt.cpp - Adjust Optimization --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Adjust optimization to make the code more kernel verifier friendly.
+// For example, the following optimization is undone:
+//   - InstCombineAndOrXor
+//
+//===----------------------------------------------------------------------===//
+
+#include "BPF.h"
+#include "BPFTargetMachine.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include <vector>
+
+#define DEBUG_TYPE "bpf-adjust-opt"
+
+using namespace llvm;
+
+namespace {
+
+/// Module pass that rewrites IR produced by generic optimizations into an
+/// equivalent form. The only rewrite implemented here is the one performed
+/// by adjustInstCombine().
+class BPFAdjustOpt final : public ModulePass {
+  StringRef getPassName() const override { return "BPF Adjust Optimization"; }
+
+  bool runOnModule(Module &M) override;
+
+public:
+  static char ID;
+  BPFAdjustOpt() : ModulePass(ID) {}
+
+private:
+  /// Return true if \p V, interpreted as a signed value, fits in 64 bits.
+  bool Is64BitSigned(const APInt &V);
+  /// Run every adjustment on each function of \p M. Return true if the IR
+  /// changed.
+  bool adjustOpt(Module &M);
+  /// Turn unsigned range checks in \p F back into pairs of signed
+  /// comparisons. Return true if the IR changed.
+  bool adjustInstCombine(Module &M, Function &F);
+};
+} // End anonymous namespace
+
+char BPFAdjustOpt::ID = 0;
+INITIALIZE_PASS(BPFAdjustOpt, DEBUG_TYPE, "adjust optimization", false, false)
+
+ModulePass *llvm::createBPFAdjustOpt() { return new BPFAdjustOpt(); }
+
+bool BPFAdjustOpt::runOnModule(Module &M) {
+  LLVM_DEBUG(dbgs() << "******** BPF Adjust Optimization ********\n");
+
+  return adjustOpt(M);
+}
+
+bool BPFAdjustOpt::Is64BitSigned(const APInt &V) {
+  // A negative value must not be below INT64_MIN, a non-negative one must not
+  // be above INT64_MAX.
+  if (V.isNegative())
+    return V.sge(INT64_MIN);
+  return V.ule(INT64_MAX);
+}
+
+bool BPFAdjustOpt::adjustInstCombine(Module &M, Function &F) {
+  bool Changed = false;
+  // Erasing is postponed until the walk is over so that no instruction is
+  // deleted while its basic block is still being iterated.
+  std::vector<Instruction *> RemovedInsts;
+
+  for (auto &BB : F)
+    for (auto &I : BB) {
+      // Same test as getNumUses() <= 1, but the use list is not walked past
+      // its second entry.
+      if (!I.hasNUsesOrMore(2))
+        continue;
+
+      // The instruction has more than one use.
+      // Check whether one use is in a path refined by the
+      // InstCombine, and undo the optimization.
+      // The following transformations are supported:
+      //   V - Lo u<  Hi - Lo  -->  V >= Lo && V <  Hi
+      //   V - Lo u>  Hi - Lo  -->  V <  Lo || V >  Hi
+      //   V - Lo u<= Hi - Lo  -->  V >= Lo && V <= Hi
+      //   V - Lo u>= Hi - Lo  -->  V <  Lo || V >= Hi
+      for (User *U : I.users()) {
+        // First instruction (ArithInst): tmp = V - Lo or tmp = V + -Lo
+        auto *ArithInst = dyn_cast<Instruction>(U);
+        if (!ArithInst || !ArithInst->hasOneUse())
+          continue;
+
+        const unsigned Opcode = ArithInst->getOpcode();
+        if (Opcode != Instruction::Sub && Opcode != Instruction::Add)
+          continue;
+
+        const auto *CV = dyn_cast<ConstantInt>(ArithInst->getOperand(1));
+        if (!CV)
+          continue;
+
+        // An add carries the bound negated (V + -Lo), so flip the sign back.
+        APInt ValLo = CV->getValue();
+        if (Opcode == Instruction::Add)
+          ValLo = -ValLo;
+
+        // Second instruction (CmpInst):
+        //   tmp u<[=] (Hi - Lo) or tmp u>[=] (Hi - Lo)
+        auto *CmpInst = dyn_cast<ICmpInst>(*ArithInst->user_begin());
+        if (!CmpInst)
+          continue;
+
+        // Only the four unsigned ordering predicates are rewritten.
+        const ICmpInst::Predicate Pred = CmpInst->getPredicate();
+        if (Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_UGE &&
+            Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_ULE)
+          continue;
+
+        CV = dyn_cast<ConstantInt>(CmpInst->getOperand(1));
+        if (!CV)
+          continue;
+
+        APInt ValHi = ValLo + CV->getValue();
+
+        // Ensure ValLo/ValHi in 64bit signed int range, and Lo < Hi.
+        // The Lo < Hi test also rejects a ValHi that wrapped around.
+        if (!Is64BitSigned(ValLo) || !Is64BitSigned(ValHi) || ValHi.sle(ValLo))
+          continue;
+
+        Value *NewValLo = ConstantInt::get(CV->getType(), ValLo);
+        Value *NewValHi = ConstantInt::get(CV->getType(), ValHi);
+
+        // Add new instructions right before ArithInst
+        IRBuilder<> Builder(ArithInst);
+        Value *NewCond = nullptr;
+        if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
+          // %call: I
+          // %cmp1 = icmp slt i32 %call, Lo
+          // %cmp2 = icmp {sgt|sge} i32 %call, Hi
+          // %or.cond = or i1 %cmp1, %cmp2
+          // replace all users of CmpInst with %or.cond
+          Value *Cmp1 = Builder.CreateICmpSLT(&I, NewValLo);
+          Value *Cmp2 = Pred == ICmpInst::ICMP_UGT
+                            ? Builder.CreateICmpSGT(&I, NewValHi)
+                            : Builder.CreateICmpSGE(&I, NewValHi);
+          NewCond = Builder.CreateOr(Cmp1, Cmp2);
+        } else {
+          // %call: I
+          // %cmp1 = icmp sge i32 %call, Lo
+          // %cmp2 = icmp {slt|sle} i32 %call, Hi
+          // %and.cond = and i1 %cmp1, %cmp2
+          // replace all users of CmpInst with %and.cond
+          Value *Cmp1 = Builder.CreateICmpSGE(&I, NewValLo);
+          Value *Cmp2 = Pred == ICmpInst::ICMP_ULT
+                            ? Builder.CreateICmpSLT(&I, NewValHi)
+                            : Builder.CreateICmpSLE(&I, NewValHi);
+          NewCond = Builder.CreateAnd(Cmp1, Cmp2);
+        }
+        CmpInst->replaceAllUsesWith(NewCond);
+
+        // CmpInst is queued first because it is the single user of ArithInst.
+        RemovedInsts.push_back(CmpInst);
+        RemovedInsts.push_back(ArithInst);
+
+        Changed = true;
+        // The new compares were added to the use list of I; stop walking it.
+        break;
+      }
+    }
+
+  for (auto *Inst : RemovedInsts)
+    Inst->eraseFromParent();
+
+  return Changed;
+}
+
+bool BPFAdjustOpt::adjustOpt(Module &M) {
+  bool Changed = false;
+
+  for (Function &F : M)
+    Changed |= adjustInstCombine(M, F);
+
+  return Changed;
+}
diff --git a/llvm/lib/Target/BPF/BPFTargetMachine.cpp b/llvm/lib/Target/BPF/BPFTargetMachine.cpp
--- a/llvm/lib/Target/BPF/BPFTargetMachine.cpp
+++ b/llvm/lib/Target/BPF/BPFTargetMachine.cpp
@@ -34,6 +34,7 @@
   RegisterTargetMachine<BPFTargetMachine> Z(getTheBPFTarget());
 
   PassRegistry &PR = *PassRegistry::getPassRegistry();
+  initializeBPFAdjustOptPass(PR);
   initializeBPFAbstractMemberAccessPass(PR);
   initializeBPFPreserveDITypePass(PR);
   initializeBPFMIPeepholePass(PR);
@@ -96,6 +97,7 @@
 
 void BPFPassConfig::addIRPasses() {
 
+  addPass(createBPFAdjustOpt());
   addPass(createBPFAbstractMemberAccess(&getBPFTargetMachine()));
   addPass(createBPFPreserveDIType());
 
diff --git a/llvm/lib/Target/BPF/CMakeLists.txt b/llvm/lib/Target/BPF/CMakeLists.txt
--- a/llvm/lib/Target/BPF/CMakeLists.txt
+++ b/llvm/lib/Target/BPF/CMakeLists.txt
@@ -14,6 +14,7 @@
 
 add_llvm_target(BPFCodeGen
   BPFAbstractMemberAccess.cpp
+  BPFAdjustOpt.cpp
   BPFAsmPrinter.cpp
   BPFFrameLowering.cpp
   BPFInstrInfo.cpp
diff --git a/llvm/test/CodeGen/BPF/adjust-instcombine-1.ll b/llvm/test/CodeGen/BPF/adjust-instcombine-1.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/BPF/adjust-instcombine-1.ll
@@ -0,0 +1,64 @@
+; RUN: llc < %s -march=bpfel | FileCheck --check-prefixes=CHECK,CHECK-V1 %s
+; RUN: llc < %s -march=bpfel -mcpu=v3 | FileCheck --check-prefixes=CHECK,CHECK-V3 %s
+;
+; Source Code:
+;   char value[7];
+;   extern int ext_test(void *);
+;   int test() {
+;     int i, ret, off = 0;
+;
+;     #pragma clang loop unroll(disable)
+;     for (i = 0; i < 50; ++i) {
+;       ret = ext_test(value + off);
+;       if (ret <= 0 || ret > 7)
+;         return 0;
+;       off += ret & 7;
+;     }
+;     return 0;
+;   }
+; Compilation flag:
+;   clang -target bpf -O2 -S -emit-llvm test.c
+
+@value = common dso_local global [7 x i8] zeroinitializer, align 1
+; Function Attrs: nounwind
+define dso_local i32 @test() local_unnamed_addr #0 {
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %if.end, %entry
+  %off.013 = phi i32 [ 0, %entry ], [ %add, %if.end ]
+  %i.012 = phi i32 [ 0, %entry ], [ %inc, %if.end ]
+  %idx.ext = zext i32 %off.013 to i64
+  %add.ptr = getelementptr inbounds [7 x i8], [7 x i8]* @value, i64 0, i64 %idx.ext
+  %call = tail call i32 @ext_test(i8* nonnull %add.ptr) #2
+  %call.off = add i32 %call, -1
+  %0 = icmp ugt i32 %call.off, 6
+  br i1 %0, label %cleanup, label %if.end
+; CHECK: call ext_test
+; CHECK-V1: r0 <<= 32
+; CHECK-V1: r0 s>>= 32
+; CHECK-V1: if r{{[0-9]+}} s> 7 goto
+; CHECK-V3: if w{{[0-9]+}} s> 7 goto
+
+if.end:                                           ; preds = %for.body
+  %add = add nuw nsw i32 %call, %off.013
+  %inc = add nuw nsw i32 %i.012, 1
+  %exitcond = icmp eq i32 %inc, 50
+  br i1 %exitcond, label %cleanup, label %for.body, !llvm.loop !2
+
+cleanup:                                          ; preds = %if.end, %for.body
+  ret i32 0
+}
+declare dso_local i32 @ext_test(i8*) local_unnamed_addr #1
+
+attributes #0 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #1 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #2 = { nounwind }
+
+!llvm.module.flags = !{!0}
+!llvm.ident = !{!1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{!"clang version 11.0.0 (https://github.com/llvm/llvm-project.git ca2e7bed9f7835eec20acfdbabed0617126561aa)"}
+!2 = distinct !{!2, !3}
+!3 = !{!"llvm.loop.unroll.disable"}
diff --git a/llvm/test/CodeGen/BPF/adjust-instcombine-2.ll b/llvm/test/CodeGen/BPF/adjust-instcombine-2.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/BPF/adjust-instcombine-2.ll
@@ -0,0 +1,48 @@
+; RUN: llc < %s -march=bpfel | FileCheck --check-prefix=CHECK %s
+; RUN: llc < %s -march=bpfel -mcpu=v3 | FileCheck --check-prefix=CHECK-V3 %s
+;
+; Source Code:
+;   int ext_test(int);
+;   int test(int *len) {
+;     char options[10] = {};
+;     int options_len = *len;
+;     if (options_len < 4 || options_len > 10) return 0;
+;     return ext_test(options_len);
+;   }
+; Compilation flag:
+;   clang -target bpf -O2 -S -emit-llvm test.c
+
+; Function Attrs: nounwind
+define dso_local i32 @test(i32* nocapture readonly %len) local_unnamed_addr #0 {
+entry:
+  %0 = load i32, i32* %len, align 4, !tbaa !2
+  %.off = add i32 %0, -4
+  %1 = icmp ugt i32 %.off, 6
+  br i1 %1, label %cleanup, label %if.end
+
+; CHECK: if r{{[0-9]+}} s> 10
+; CHECK-V3: if w{{[0-9]+}} s> 10
+
+if.end:                                           ; preds = %entry
+  %call = tail call i32 @ext_test(i32 %0) #2
+  br label %cleanup
+
+cleanup:                                          ; preds = %entry, %if.end
+  %retval.0 = phi i32 [ %call, %if.end ], [ 0, %entry ]
+  ret i32 %retval.0
+}
+declare dso_local i32 @ext_test(i32) local_unnamed_addr #1
+
+attributes #0 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #1 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #2 = { nounwind }
+
+!llvm.module.flags = !{!0}
+!llvm.ident = !{!1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{!"clang version 11.0.0 (https://github.com/llvm/llvm-project.git ca2e7bed9f7835eec20acfdbabed0617126561aa)"}
+!2 = !{!3, !3, i64 0}
+!3 = !{!"int", !4, i64 0}
+!4 = !{!"omnipotent char", !5, i64 0}
+!5 = !{!"Simple C/C++ TBAA"}