diff --git a/llvm/include/llvm/IR/IntrinsicsBPF.td b/llvm/include/llvm/IR/IntrinsicsBPF.td
--- a/llvm/include/llvm/IR/IntrinsicsBPF.td
+++ b/llvm/include/llvm/IR/IntrinsicsBPF.td
@@ -34,4 +34,7 @@
               [IntrNoMem]>;
   def int_bpf_passthrough : GCCBuiltin<"__builtin_bpf_passthrough">,
               Intrinsic<[llvm_any_ty], [llvm_i32_ty, llvm_any_ty], [IntrNoMem]>;
+  def int_bpf_compare : GCCBuiltin<"__builtin_bpf_compare">,
+              Intrinsic<[llvm_i1_ty], [llvm_i32_ty, llvm_anyint_ty, llvm_anyint_ty],
+                        [IntrNoMem]>;
 }
diff --git a/llvm/lib/Target/BPF/BPFAdjustOpt.cpp b/llvm/lib/Target/BPF/BPFAdjustOpt.cpp
--- a/llvm/lib/Target/BPF/BPFAdjustOpt.cpp
+++ b/llvm/lib/Target/BPF/BPFAdjustOpt.cpp
@@ -15,6 +15,7 @@
 #include "BPFTargetMachine.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicsBPF.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Type.h"
@@ -66,6 +67,7 @@
   Module *M;
   SmallVector<PassThroughInfo, 16> PassThroughs;
 
+  bool adjustICmpToBuiltin();
   void adjustBasicBlock(BasicBlock &BB);
   bool serializeICMPCrossBB(BasicBlock &BB);
   void adjustInst(Instruction &I);
@@ -85,14 +87,89 @@
 bool BPFAdjustOpt::runOnModule(Module &M) { return BPFAdjustOptImpl(&M).run(); }
 
 bool BPFAdjustOptImpl::run() {
+  bool Changed = adjustICmpToBuiltin();
+
   for (Function &F : *M)
     for (auto &BB : F) {
       adjustBasicBlock(BB);
       for (auto &I : BB)
         adjustInst(I);
     }
+  return insertPassThrough() || Changed;
+}
+
+// Commit acabad9ff6bf ("[InstCombine] try to canonicalize icmp with
+// trunc op into mask and cmp") added a transformation to
+// convert "(conv)a < power_2_const" to "a & <mask>" in certain
+// cases and bpf kernel verifier has to handle the resulting code
+// conservatively and this may reject otherwise legitimate programs.
+// Here, we change related icmp code to a builtin which will
+// be restored to original icmp code later to prevent that
+// InstCombine transformation.
+//
+// Only an icmp whose first operand is a trunc and whose second
+// operand is a constant is rewritten: ult/uge with a power-of-2
+// constant, or ule/ugt with a (2^n - 1) constant (zero passes both).
+bool BPFAdjustOptImpl::adjustICmpToBuiltin() {
+  bool Changed = false;
+  ICmpInst *ToBeDeleted = nullptr;
+  for (Function &F : *M)
+    for (auto &BB : F)
+      for (auto &I : BB) {
+        // Erase the icmp replaced in the previous iteration; erasing it
+        // while it is still the current instruction would break the loop.
+        if (ToBeDeleted) {
+          ToBeDeleted->eraseFromParent();
+          ToBeDeleted = nullptr;
+        }
+
+        auto *Icmp = dyn_cast<ICmpInst>(&I);
+        if (!Icmp)
+          continue;
+
+        Value *Op0 = Icmp->getOperand(0);
+        if (!isa<TruncInst>(Op0))
+          continue;
+
+        auto *ConstOp1 = dyn_cast<ConstantInt>(Icmp->getOperand(1));
+        if (!ConstOp1)
+          continue;
+
+        // Use APInt arithmetic: getZExtValue() asserts on constants
+        // wider than 64 bits. APInt wraps at the constant's bit width.
+        const APInt &ConstOp1Val = ConstOp1->getValue();
+        auto Op = Icmp->getPredicate();
+        if (Op == ICmpInst::ICMP_ULT || Op == ICmpInst::ICMP_UGE) {
+          // Require a power-of-2 (or zero) constant.
+          if ((ConstOp1Val & (ConstOp1Val - 1)) != 0)
+            continue;
+        } else if (Op == ICmpInst::ICMP_ULE || Op == ICmpInst::ICMP_UGT) {
+          // Require a low-bit mask (2^n - 1, or zero) constant.
+          if ((ConstOp1Val & (ConstOp1Val + 1)) != 0)
+            continue;
+        } else {
+          continue;
+        }
+
+        Constant *Opcode =
+            ConstantInt::get(Type::getInt32Ty(BB.getContext()), Op);
+        Function *Fn = Intrinsic::getDeclaration(
+            M, Intrinsic::bpf_compare, {Op0->getType(), ConstOp1->getType()});
+        auto *NewInst = CallInst::Create(Fn, {Opcode, Op0, ConstOp1});
+        BB.getInstList().insert(I.getIterator(), NewInst);
+        // Preserve the icmp's source location and value name.
+        NewInst->setDebugLoc(Icmp->getDebugLoc());
+        NewInst->takeName(Icmp);
+        Icmp->replaceAllUsesWith(NewInst);
+        Changed = true;
+        ToBeDeleted = Icmp;
+      }
 
-  return insertPassThrough();
+  // Defensive: flush a deletion still pending after the last iteration.
+  if (ToBeDeleted)
+    ToBeDeleted->eraseFromParent();
+
+  return Changed;
 }
 
 bool BPFAdjustOptImpl::insertPassThrough() {
diff --git a/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp b/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp
--- a/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp
+++ b/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp
@@ -46,6 +46,7 @@
   void checkIR(Module &M);
   bool adjustIR(Module &M);
   bool removePassThroughBuiltin(Module &M);
+  bool removeCompareBuiltin(Module &M);
 };
 } // End anonymous namespace
 
@@ -120,8 +121,60 @@
   return Changed;
 }
 
+bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) {
+  // Replace each llvm.bpf.compare call (created by BPFAdjustOpt to hide an
+  // icmp from certain IR optimizations) with the icmp it encodes. The major
+  // IR optimizations are done by now, so the shield is no longer needed.
+  bool Changed = false;
+  CallInst *ToBeDeleted = nullptr;
+  for (Function &F : M)
+    for (auto &BB : F)
+      for (auto &I : BB) {
+        // Erase the call replaced in the previous iteration; erasing it
+        // while it is still the current instruction would break the loop.
+        if (ToBeDeleted) {
+          ToBeDeleted->eraseFromParent();
+          ToBeDeleted = nullptr;
+        }
+
+        auto *Call = dyn_cast<CallInst>(&I);
+        if (!Call)
+          continue;
+        auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
+        if (!GV)
+          continue;
+        if (!GV->getName().startswith("llvm.bpf.compare"))
+          continue;
+
+        Changed = true;
+        Value *Arg0 = Call->getArgOperand(0);
+        Value *Arg1 = Call->getArgOperand(1);
+        Value *Arg2 = Call->getArgOperand(2);
+
+        // Arg0 carries the original CmpInst::Predicate as an i32 constant.
+        auto OpVal = cast<ConstantInt>(Arg0)->getValue().getZExtValue();
+        auto Opcode = static_cast<CmpInst::Predicate>(OpVal);
+
+        auto *ICmp = new ICmpInst(Opcode, Arg1, Arg2);
+        BB.getInstList().insert(Call->getIterator(), ICmp);
+        // Preserve the call's source location and value name.
+        ICmp->setDebugLoc(Call->getDebugLoc());
+        ICmp->takeName(Call);
+
+        Call->replaceAllUsesWith(ICmp);
+        ToBeDeleted = Call;
+      }
+
+  // Defensive: flush a deletion still pending after the last iteration.
+  if (ToBeDeleted)
+    ToBeDeleted->eraseFromParent();
+  return Changed;
+}
+
 bool BPFCheckAndAdjustIR::adjustIR(Module &M) {
-  return removePassThroughBuiltin(M);
+  bool Changed = removePassThroughBuiltin(M);
+  Changed = removeCompareBuiltin(M) || Changed;
+  return Changed;
 }
 
 bool BPFCheckAndAdjustIR::runOnModule(Module &M) {
diff --git a/llvm/test/CodeGen/BPF/adjust-opt-icmp3.ll b/llvm/test/CodeGen/BPF/adjust-opt-icmp3.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/BPF/adjust-opt-icmp3.ll
@@ -0,0 +1,85 @@
+; RUN: opt -O2 -S -mtriple=bpf-pc-linux %s -o %t1
+; RUN: llc %t1 -o - | FileCheck -check-prefixes=CHECK,CHECK-V1 %s
+; RUN: opt -O2 -S -mtriple=bpf-pc-linux %s -o %t1
+; RUN: llc %t1 -mcpu=v3 -o - | FileCheck -check-prefixes=CHECK,CHECK-V3 %s
+;
+; Source:
+;   int test1(unsigned long a) {
+;     if ((unsigned)a <= 3) return 2;
+;     return 3;
+;   }
+;   int test2(unsigned long a) {
+;     if ((unsigned)a < 4) return 2;
+;     return 3;
+;   }
+; Compilation flag:
+;   clang -target bpf -O2 -S -emit-llvm -Xclang -disable-llvm-passes test.c
+
+; Function Attrs: nounwind
+define dso_local i32 @test1(i64 %a) #0 {
+entry:
+  %retval = alloca i32, align 4
+  %a.addr = alloca i64, align 8
+  store i64 %a, i64* %a.addr, align 8, !tbaa !3
+  %0 = load i64, i64* %a.addr, align 8, !tbaa !3
+  %conv = trunc i64 %0 to i32
+  %cmp = icmp ule i32 %conv, 3
+  br i1 %cmp, label %if.then, label %if.end
+
+if.then:                                          ; preds = %entry
+  store i32 2, i32* %retval, align 4
+  br label %return
+
+if.end:                                           ; preds = %entry
+  store i32 3, i32* %retval, align 4
+  br label %return
+
+return:                                           ; preds = %if.end, %if.then
+  %1 = load i32, i32* %retval, align 4
+  ret i32 %1
+}
+
+; CHECK-LABEL: test1
+; CHECK-V1: if r[[#]] > r[[#]] goto
+; CHECK-V3: if w[[#]] < 4 goto
+
+; Function Attrs: nounwind
+define dso_local i32 @test2(i64 %a) #0 {
+entry:
+  %retval = alloca i32, align 4
+  %a.addr = alloca i64, align 8
+  store i64 %a, i64* %a.addr, align 8, !tbaa !3
+  %0 = load i64, i64* %a.addr, align 8, !tbaa !3
+  %conv = trunc i64 %0 to i32
+  %cmp = icmp ult i32 %conv, 4
+  br i1 %cmp, label %if.then, label %if.end
+
+if.then:                                          ; preds = %entry
+  store i32 2, i32* %retval, align 4
+  br label %return
+
+if.end:                                           ; preds = %entry
+  store i32 3, i32* %retval, align 4
+  br label %return
+
+return:                                           ; preds = %if.end, %if.then
+  %1 = load i32, i32* %retval, align 4
+  ret i32 %1
+}
+
+; CHECK-LABEL: test2
+; CHECK-V1: if r[[#]] > r[[#]] goto
+; CHECK-V3: if w[[#]] < 4 goto
+
+attributes #0 = { nounwind "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+
+!llvm.module.flags = !{!0, !1}
+!llvm.ident = !{!2}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 7, !"frame-pointer", i32 2}
+!2 = !{!"clang version 14.0.0 (https://github.com/llvm/llvm-project.git b7892f95881c891032742e0cd81861b845512653)"}
+!3 = !{!4, !4, i64 0}
+!4 = !{!"long", !5, i64 0}
+!5 = !{!"omnipotent char", !6, i64 0}
+!6 = !{!"Simple C/C++ TBAA"}