diff --git a/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp b/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp
--- a/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp
+++ b/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp
@@ -18,8 +18,10 @@
 #include "BPF.h"
 #include "BPFCORE.h"
 #include "BPFTargetMachine.h"
+#include "llvm/Analysis/LoopInfo.h"
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
@@ -41,12 +43,14 @@
 public:
   static char ID;
   BPFCheckAndAdjustIR() : ModulePass(ID) {}
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
 
 private:
   void checkIR(Module &M);
   bool adjustIR(Module &M);
   bool removePassThroughBuiltin(Module &M);
   bool removeCompareBuiltin(Module &M);
+  bool sinkMinMax(Module &M);
 };
 } // End anonymous namespace
 
@@ -161,9 +165,221 @@
   return Changed;
 }
 
+// Describes an `icmp` that compares Other against a min/max call, possibly
+// wrapped in a zext/sext. Predicate is normalized so that the comparison
+// reads as `Other Predicate MinMax`.
+struct MinMaxSinkInfo {
+  ICmpInst *ICmp;
+  Value *Other;
+  ICmpInst::Predicate Predicate;
+  CallInst *MinMax;
+  ZExtInst *ZExt;
+  SExtInst *SExt;
+
+  MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate)
+      : ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr),
+        ZExt(nullptr), SExt(nullptr) {}
+};
+
+// Rewrite each relational `icmp` in BB that compares a value against a
+// (possibly extended) min/max call accepted by Filter into two comparisons
+// joined by a logical and/or. Returns true if BB was modified.
+static bool sinkMinMaxInBB(BasicBlock &BB,
+                           const std::function<bool(Instruction *)> &Filter) {
+  // Check if V is:
+  //   (fn %a %b) or (ext (fn %a %b))
+  // Where:
+  //   ext := sext | zext
+  //   fn  := smin | umin | smax | umax
+  auto IsMinMaxCall = [&](Value *V, MinMaxSinkInfo &Info) {
+    if (auto *ZExt = dyn_cast<ZExtInst>(V)) {
+      V = ZExt->getOperand(0);
+      Info.ZExt = ZExt;
+    } else if (auto *SExt = dyn_cast<SExtInst>(V)) {
+      V = SExt->getOperand(0);
+      Info.SExt = SExt;
+    }
+
+    auto *Call = dyn_cast<CallInst>(V);
+    if (!Call)
+      return false;
+
+    auto *Called = dyn_cast<Function>(Call->getCalledOperand());
+    if (!Called)
+      return false;
+
+    switch (Called->getIntrinsicID()) {
+    case Intrinsic::smin:
+    case Intrinsic::umin:
+    case Intrinsic::smax:
+    case Intrinsic::umax:
+      break;
+    default:
+      return false;
+    }
+
+    if (!Filter(Call))
+      return false;
+
+    Info.MinMax = Call;
+
+    return true;
+  };
+
+  auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V,
+                             MinMaxSinkInfo &Info) {
+    if (Info.SExt) {
+      if (Info.SExt->getType() == V->getType())
+        return V;
+      return Builder.CreateSExt(V, Info.SExt->getType());
+    }
+    if (Info.ZExt) {
+      if (Info.ZExt->getType() == V->getType())
+        return V;
+      return Builder.CreateZExt(V, Info.ZExt->getType());
+    }
+    return V;
+  };
+
+  bool Changed = false;
+  SmallVector<MinMaxSinkInfo, 2> SinkList;
+
+  // Check BB for instructions like:
+  //   insn := (icmp %a (fn ...)) | (icmp (fn ...) %a)
+  //
+  // Where:
+  //   fn  := min | max | (ext (min ...)) | (ext (max ...))
+  //   ext := sext | zext
+  //
+  // Put such instructions to SinkList.
+  for (Instruction &I : BB) {
+    ICmpInst *ICmp = dyn_cast<ICmpInst>(&I);
+    if (!ICmp)
+      continue;
+    if (!ICmp->isRelational())
+      continue;
+    MinMaxSinkInfo First(ICmp, ICmp->getOperand(1),
+                         ICmpInst::getSwappedPredicate(ICmp->getPredicate()));
+    MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate());
+    bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First);
+    bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second);
+    if (!(FirstMinMax ^ SecondMinMax))
+      continue;
+    SinkList.push_back(FirstMinMax ? First : Second);
+  }
+
+  // Iterate SinkList and replace each (icmp ...) with corresponding
+  // `x < a && x < b` or similar expression.
+  for (auto &Info : SinkList) {
+    ICmpInst *ICmp = Info.ICmp;
+    CallInst *MinMax = Info.MinMax;
+    Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID();
+    ICmpInst::Predicate P = Info.Predicate;
+    bool IsSignedMinMax = IID == Intrinsic::smin || IID == Intrinsic::smax;
+    // The rewrite is only sound when the predicate and min/max use the same
+    // ordering: e.g. `x <u smin(a, b)` differs from `x <u a && x <u b`
+    // for a = -1, b = 1, x = 5.
+    if (ICmpInst::isSigned(P) != IsSignedMinMax)
+      continue;
+    // sext preserves both signed and unsigned order, zext preserves only
+    // unsigned order, so zext(smin(a, b)) != smin(zext(a), zext(b)).
+    if (IsSignedMinMax && Info.ZExt)
+      continue;
+
+    IRBuilder<> Builder(ICmp);
+    Value *X = Info.Other;
+    Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info);
+    Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info);
+    bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin;
+    bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax;
+    bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P);
+    bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P);
+    assert(IsMin ^ IsMax);
+    assert(IsLess ^ IsGreater);
+
+    Value *Replacement;
+    Value *LHS = Builder.CreateICmp(P, X, A);
+    Value *RHS = Builder.CreateICmp(P, X, B);
+    if ((IsLess && IsMin) || (IsGreater && IsMax))
+      // x < min(a, b) -> x < a && x < b
+      // x > max(a, b) -> x > a && x > b
+      Replacement = Builder.CreateLogicalAnd(LHS, RHS);
+    else
+      // x > min(a, b) -> x > a || x > b
+      // x < max(a, b) -> x < a || x < b
+      Replacement = Builder.CreateLogicalOr(LHS, RHS);
+
+    ICmp->replaceAllUsesWith(Replacement);
+
+    // The extension and the min/max call may still have other users.
+    Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax};
+    for (Instruction *I : ToRemove)
+      if (I && I->use_empty())
+        I->eraseFromParent();
+
+    Changed = true;
+  }
+
+  return Changed;
+}
+
+// Do the following transformation:
+//
+//   x < min(a, b) -> x < a && x < b
+//   x > min(a, b) -> x > a || x > b
+//   x < max(a, b) -> x < a || x < b
+//   x > max(a, b) -> x > a && x > b
+//
+// Such patterns are introduced by LICM.cpp:hoistMinMax()
+// transformation and might lead to BPF verification failures for
+// older kernels.
+//
+// To minimize "collateral" changes only do it for icmp + min/max
+// calls when icmp is inside a loop and min/max is outside of that
+// loop.
+//
+// Verification failure happens when:
+// - RHS operand of some `icmp LHS, RHS` is replaced by some RHS1;
+// - verifier can recognize RHS as a constant scalar in some context;
+// - verifier can't recognize RHS1 as a constant scalar in the same
+//   context;
+//
+// The "constant scalar" is not a compile time constant, but a register
+// that holds a scalar value known to verifier at some point in time
+// during abstract interpretation.
+//
+// See also:
+//   https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/
+bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) {
+  bool Changed = false;
+
+  for (Function &F : M) {
+    if (F.isDeclaration())
+      continue;
+
+    LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo();
+    for (Loop *L : LI)
+      for (BasicBlock *BB : L->blocks()) {
+        // Filter out instructions coming from the same loop
+        Loop *BBLoop = LI.getLoopFor(BB);
+        auto OtherLoopFilter = [&](Instruction *I) {
+          return LI.getLoopFor(I->getParent()) != BBLoop;
+        };
+        Changed |= sinkMinMaxInBB(*BB, OtherLoopFilter);
+      }
+  }
+
+  return Changed;
+}
+
+void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.addRequired<LoopInfoWrapperPass>();
+}
+
 bool BPFCheckAndAdjustIR::adjustIR(Module &M) {
   bool Changed = removePassThroughBuiltin(M);
   Changed = removeCompareBuiltin(M) || Changed;
+  Changed = sinkMinMax(M) || Changed;
   return Changed;
 }
 
diff --git a/llvm/test/CodeGen/BPF/sink-min-max.ll b/llvm/test/CodeGen/BPF/sink-min-max.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/BPF/sink-min-max.ll
@@ -0,0 +1,326 @@
+; RUN: opt --bpf-check-and-opt-ir -S -mtriple=bpf-pc-linux %s | FileCheck %s
+
+; Test plan:
+; @test1: x < umin(i64 a, i64 b)
+; @test2: x < umax(i64 a, i64 b)
+; @test3: x >= umin(i64 a, i64 b)
+; @test4: x >= umax(i64 a, i64 b)
+; @test5: umin(i64 a, i64 b) >= x
+; @test6: x < smin(i64 a, i64 b)
+; @test7: x < umin(i32 a, i32 b)
+; @test8: x < zext i64 umin(i32 a, i32 b)
+; @test9: x < sext i64 umin(i32 a, i32 b)
+; @test10: check that umin belonging to the same loop is not touched
+; @test11: check that nested loops are processed
+; @test12: check that x <u smin(i64 a, i64 b) is not touched
+; @test13: check that x <s zext i64 smin(i32 a, i32 b) is not touched
+; @test14: x <s sext i64 smin(i32 a, i32 b)
+
+define i32 @test1(i64 %a, i64 %b, i64 %x) {
+entry:
+  %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
+  br label %loop
+loop:
+  %cmp = icmp ult i64 %x, %min
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test1
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %0 = icmp ult i64 %x, %a
+; CHECK-NEXT: %1 = icmp ult i64 %x, %b
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %loop, label %ret
+
+define i32 @test2(i64 %a, i64 %b, i64 %x) {
+entry:
+  %max = tail call i64 @llvm.umax.i64(i64 %a, i64 %b)
+  br label %loop
+loop:
+  %cmp = icmp ult i64 %x, %max
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test2
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %0 = icmp ult i64 %x, %a
+; CHECK-NEXT: %1 = icmp ult i64 %x, %b
+; CHECK-NEXT: %2 = select i1 %0, i1 true, i1 %1
+; CHECK-NEXT: br i1 %2, label %loop, label %ret
+
+define i32 @test3(i64 %a, i64 %b, i64 %x) {
+entry:
+  %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
+  br label %loop
+loop:
+  %cmp = icmp uge i64 %x, %min
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test3
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %0 = icmp uge i64 %x, %a
+; CHECK-NEXT: %1 = icmp uge i64 %x, %b
+; CHECK-NEXT: %2 = select i1 %0, i1 true, i1 %1
+; CHECK-NEXT: br i1 %2, label %loop, label %ret
+
+define i32 @test4(i64 %a, i64 %b, i64 %x) {
+entry:
+  %max = tail call i64 @llvm.umax.i64(i64 %a, i64 %b)
+  br label %loop
+loop:
+  %cmp = icmp uge i64 %x, %max
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test4
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %0 = icmp uge i64 %x, %a
+; CHECK-NEXT: %1 = icmp uge i64 %x, %b
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %loop, label %ret
+
+define i32 @test5(i64 %a, i64 %b, i64 %x) {
+entry:
+  %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
+  br label %loop
+loop:
+  %cmp = icmp uge i64 %min, %x
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test5
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK: %0 = icmp ule i64 %x, %a
+; CHECK-NEXT: %1 = icmp ule i64 %x, %b
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %loop, label %ret
+
+define i32 @test6(i64 %a, i64 %b, i64 %x) {
+entry:
+  %min = tail call i64 @llvm.smin.i64(i64 %a, i64 %b)
+  br label %loop
+loop:
+  %cmp = icmp slt i64 %x, %min
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test6
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK: %0 = icmp slt i64 %x, %a
+; CHECK-NEXT: %1 = icmp slt i64 %x, %b
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %loop, label %ret
+
+define i32 @test7(i32 %a, i32 %b, i32 %x) {
+entry:
+  %min = tail call i32 @llvm.umin.i32(i32 %a, i32 %b)
+  br label %loop
+loop:
+  %cmp = icmp ult i32 %x, %min
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test7
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK: %0 = icmp ult i32 %x, %a
+; CHECK-NEXT: %1 = icmp ult i32 %x, %b
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %loop, label %ret
+
+define i32 @test8(i32 %a, i32 %b, i64 %x) {
+entry:
+  %min = tail call i32 @llvm.umin.i32(i32 %a, i32 %b)
+  br label %loop
+loop:
+  %ext = zext i32 %min to i64
+  %cmp = icmp ult i64 %x, %ext
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test8
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %0 = zext i32 %a to i64
+; CHECK-NEXT: %1 = zext i32 %b to i64
+; CHECK-NEXT: %2 = icmp ult i64 %x, %0
+; CHECK-NEXT: %3 = icmp ult i64 %x, %1
+; CHECK-NEXT: %4 = select i1 %2, i1 %3, i1 false
+; CHECK-NEXT: br i1 %4, label %loop, label %ret
+
+define i32 @test9(i32 %a, i32 %b, i64 %x) {
+entry:
+  %min = tail call i32 @llvm.umin.i32(i32 %a, i32 %b)
+  br label %loop
+loop:
+  %ext = sext i32 %min to i64
+  %cmp = icmp ult i64 %x, %ext
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test9
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %0 = sext i32 %a to i64
+; CHECK-NEXT: %1 = sext i32 %b to i64
+; CHECK-NEXT: %2 = icmp ult i64 %x, %0
+; CHECK-NEXT: %3 = icmp ult i64 %x, %1
+; CHECK-NEXT: %4 = select i1 %2, i1 %3, i1 false
+; CHECK-NEXT: br i1 %4, label %loop, label %ret
+
+; umin within the loop body is unchanged
+define i32 @test10(i64 %a, i64 %b, i64 %x) {
+entry:
+  br label %loop
+loop:
+  %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
+  %cmp = icmp ult i64 %x, %min
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test10
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
+; CHECK-NEXT: %cmp = icmp ult i64 %x, %min
+; CHECK-NEXT: br i1 %cmp, label %loop, label %ret
+
+; umin from outer loop body is processed
+define i32 @test11(i64 %a, i64 %b, i64 %x) {
+entry:
+  br label %loop
+
+loop:
+  %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
+  br label %nested.loop
+nested.loop:
+  %cmp = icmp ult i64 %x, %min
+  br i1 %cmp, label %nested.loop, label %loop
+
+ret: ret i32 0
+}
+
+; CHECK: @test11
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: br label %nested.loop
+; CHECK-EMPTY:
+; CHECK-NEXT: nested.loop:
+; CHECK-NEXT: %0 = icmp ult i64 %x, %a
+; CHECK-NEXT: %1 = icmp ult i64 %x, %b
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %nested.loop, label %loop
+
+; unsigned comparison against a signed min is unchanged
+define i32 @test12(i64 %a, i64 %b, i64 %x) {
+entry:
+  %min = tail call i64 @llvm.smin.i64(i64 %a, i64 %b)
+  br label %loop
+loop:
+  %cmp = icmp ult i64 %x, %min
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test12
+; CHECK-NEXT: entry:
+; CHECK-NEXT: %min = tail call i64 @llvm.smin.i64(i64 %a, i64 %b)
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %cmp = icmp ult i64 %x, %min
+; CHECK-NEXT: br i1 %cmp, label %loop, label %ret
+
+; signed comparison against zext of a signed min is unchanged
+define i32 @test13(i32 %a, i32 %b, i64 %x) {
+entry:
+  %min = tail call i32 @llvm.smin.i32(i32 %a, i32 %b)
+  br label %loop
+loop:
+  %ext = zext i32 %min to i64
+  %cmp = icmp slt i64 %x, %ext
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test13
+; CHECK-NEXT: entry:
+; CHECK-NEXT: %min = tail call i32 @llvm.smin.i32(i32 %a, i32 %b)
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %ext = zext i32 %min to i64
+; CHECK-NEXT: %cmp = icmp slt i64 %x, %ext
+; CHECK-NEXT: br i1 %cmp, label %loop, label %ret
+
+define i32 @test14(i32 %a, i32 %b, i64 %x) {
+entry:
+  %min = tail call i32 @llvm.smin.i32(i32 %a, i32 %b)
+  br label %loop
+loop:
+  %ext = sext i32 %min to i64
+  %cmp = icmp slt i64 %x, %ext
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test14
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %0 = sext i32 %a to i64
+; CHECK-NEXT: %1 = sext i32 %b to i64
+; CHECK-NEXT: %2 = icmp slt i64 %x, %0
+; CHECK-NEXT: %3 = icmp slt i64 %x, %1
+; CHECK-NEXT: %4 = select i1 %2, i1 %3, i1 false
+; CHECK-NEXT: br i1 %4, label %loop, label %ret
+
+declare i64 @llvm.umin.i64(i64, i64)
+declare i64 @llvm.smin.i64(i64, i64)
+declare i64 @llvm.umax.i64(i64, i64)
+declare i64 @llvm.smax.i64(i64, i64)
+
+declare i32 @llvm.umin.i32(i32, i32)
+declare i32 @llvm.smin.i32(i32, i32)
+declare i32 @llvm.umax.i32(i32, i32)
+declare i32 @llvm.smax.i32(i32, i32)