diff --git a/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp b/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp
--- a/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp
+++ b/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp
@@ -20,6 +20,7 @@
 #include "BPFTargetMachine.h"
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
@@ -47,6 +48,7 @@
   bool adjustIR(Module &M);
   bool removePassThroughBuiltin(Module &M);
   bool removeCompareBuiltin(Module &M);
+  bool sinkMinMax(Module &M);
 };
 } // End anonymous namespace
 
@@ -161,9 +163,185 @@
   return Changed;
 }
 
+/// Describes one relational icmp that compares a value against a min/max
+/// intrinsic call, optionally looking through one zext/sext of that call.
+struct MinMaxSinkInfo {
+  /// The comparison that is going to be rewritten.
+  ICmpInst *ICmp;
+  /// The icmp operand on the opposite side of the min/max call.
+  Value *Other;
+  /// The icmp predicate, swapped if needed so that Other is the left operand.
+  ICmpInst::Predicate Predicate;
+  /// The matched min/max intrinsic call.
+  CallInst *MinMax;
+  /// The extension between the call and the icmp; at most one is non-null.
+  ZExtInst *ZExt;
+  SExtInst *SExt;
+
+  MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate)
+      : ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr),
+        ZExt(nullptr), SExt(nullptr) {}
+};
+
+/// Rewrite the relational comparisons in \p BB that have a min/max intrinsic
+/// call with a constant argument (possibly behind a zext/sext) as one operand
+/// into a pair of comparisons against the individual min/max arguments.
+///
+/// \returns true if at least one comparison was rewritten.
+static bool sinkMinMaxInBB(BasicBlock &BB) {
+  // Match V as [zext|sext] of a call to llvm.{s,u}{min,max} and record the
+  // matched instructions in Info. Succeeds only if an argument is constant.
+  auto IsMinMaxCall = [](Value *V, MinMaxSinkInfo &Info) {
+    if (auto *ZExt = dyn_cast<ZExtInst>(V)) {
+      V = ZExt->getOperand(0);
+      Info.ZExt = ZExt;
+    } else if (auto *SExt = dyn_cast<SExtInst>(V)) {
+      V = SExt->getOperand(0);
+      Info.SExt = SExt;
+    }
+
+    auto *Call = dyn_cast<CallInst>(V);
+    if (!Call)
+      return false;
+
+    auto *Called = dyn_cast<Function>(Call->getCalledOperand());
+    if (!Called)
+      return false;
+
+    switch (Called->getIntrinsicID()) {
+    case Intrinsic::smin:
+    case Intrinsic::umin:
+    case Intrinsic::smax:
+    case Intrinsic::umax:
+      break;
+    default:
+      return false;
+    }
+
+    Info.MinMax = Call;
+
+    return (isa<ConstantInt>(Call->getOperand(0)) ||
+            isa<ConstantInt>(Call->getOperand(1)));
+  };
+
+  // Apply the extension that wrapped the min/max call to one of its arguments
+  // so that the new comparisons are done in the type of the original icmp.
+  auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V,
+                             MinMaxSinkInfo &Info) {
+    if (Info.SExt) {
+      if (Info.SExt->getType() == V->getType())
+        return V;
+      return Builder.CreateSExt(V, Info.SExt->getType());
+    }
+    if (Info.ZExt) {
+      if (Info.ZExt->getType() == V->getType())
+        return V;
+      return Builder.CreateZExt(V, Info.ZExt->getType());
+    }
+    return V;
+  };
+
+  bool Changed = false;
+  SmallVector<MinMaxSinkInfo, 2> SinkList;
+
+  // Collect the candidates up front: the rewrite inserts and erases
+  // instructions, which is not safe while iterating over BB.
+  for (Instruction &I : BB) {
+    ICmpInst *ICmp = dyn_cast<ICmpInst>(&I);
+    if (!ICmp)
+      continue;
+    if (!ICmp->isRelational())
+      continue;
+    MinMaxSinkInfo First(ICmp, ICmp->getOperand(1),
+                         ICmpInst::getSwappedPredicate(ICmp->getPredicate()));
+    MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate());
+    bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First);
+    bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second);
+    if (!(FirstMinMax ^ SecondMinMax))
+      continue;
+    SinkList.push_back(FirstMinMax ? First : Second);
+  }
+
+  // Rewrite bottom-up: within BB a candidate can only be the Other operand of
+  // a candidate that follows it, so every Other is still alive (not yet
+  // replaced and erased) when it is read.
+  for (auto It = SinkList.rbegin(), End = SinkList.rend(); It != End; ++It) {
+    MinMaxSinkInfo &Info = *It;
+    ICmpInst *ICmp = Info.ICmp;
+    CallInst *MinMax = Info.MinMax;
+    Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID();
+    ICmpInst::Predicate P = Info.Predicate;
+    bool IsSignedMinMax = IID == Intrinsic::smin || IID == Intrinsic::smax;
+    // The rewrite is only sound when the comparison orders values the same
+    // way the min/max does. E.g. `x <u smin(a, 8)` with a == -1 and x == 9 is
+    // true, while `x <u a && x <u 8` is false.
+    if (ICmpInst::isSigned(P) != IsSignedMinMax)
+      continue;
+    // zext does not preserve the signed order (zext(-1) > zext(8)), so it
+    // cannot be distributed over the arguments of a signed min/max. sext
+    // preserves both the signed and the unsigned order.
+    if (IsSignedMinMax && Info.ZExt)
+      continue;
+
+    IRBuilder<> Builder(ICmp);
+    Value *X = Info.Other;
+    Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info);
+    Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info);
+    bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin;
+    bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax;
+    bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P);
+    bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P);
+    assert(IsMin ^ IsMax);
+    assert(IsLess ^ IsGreater);
+    Value *Replacement;
+    if ((IsLess && IsMin) || (IsGreater && IsMax))
+      // x < min(a, b) -> x < a && x < b
+      // x > max(a, b) -> x > a && x > b
+      Replacement = Builder.CreateLogicalAnd(Builder.CreateICmp(P, X, A),
+                                             Builder.CreateICmp(P, X, B));
+    else
+      // x > min(a, b) -> x > a || x > b
+      // x < max(a, b) -> x < a || x < b
+      Replacement = Builder.CreateLogicalOr(Builder.CreateICmp(P, X, A),
+                                            Builder.CreateICmp(P, X, B));
+    ICmp->replaceAllUsesWith(Replacement);
+    // Erase users before their operands, so that the min/max call is seen as
+    // dead once the icmp and the extension in front of it are gone.
+    Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax};
+    for (Instruction *I : ToRemove)
+      if (I && I->use_empty())
+        I->eraseFromParent();
+
+    Changed = true;
+  }
+
+  return Changed;
+}
+
+// Do the following transformation:
+// x < min(a, b) -> x < a && x < b
+// x > min(a, b) -> x > a || x > b
+// x < max(a, b) -> x < a || x < b
+// x > max(a, b) -> x > a && x > b
+// if a or b is a constant and the comparison has the same signedness as the
+// min/max intrinsic.
+bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) {
+  bool Changed = false;
+
+  for (Function &F : M) {
+    if (F.isDeclaration())
+      continue;
+    for (BasicBlock &BB : F)
+      Changed |= sinkMinMaxInBB(BB);
+  }
+
+  return Changed;
+}
+
 bool BPFCheckAndAdjustIR::adjustIR(Module &M) {
   bool Changed = removePassThroughBuiltin(M);
   Changed = removeCompareBuiltin(M) || Changed;
+  Changed = sinkMinMax(M) || Changed;
   return Changed;
 }
 
diff --git a/llvm/test/CodeGen/BPF/sink-min-max-u64.ll b/llvm/test/CodeGen/BPF/sink-min-max-u64.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/BPF/sink-min-max-u64.ll
@@ -0,0 +1,211 @@
+; RUN: opt --bpf-check-and-opt-ir -S -mtriple=bpf-pc-linux %s -o - \
+; RUN:   | FileCheck %s
+
+; Test plan:
+; @test1: u64 x < min(a, const)
+; @test2: x < max(a, const)
+; @test3: x >= min(a, const)
+; @test4: x >= max(a, const)
+; @test5: min(a, const) >= x
+; @test6: x < min(a, b)
+; @test7: s64 x < min(a, const)
+; @test8: u32 x < min(a, const)
+; @test9: u64 x < zext 32 umin(a, const)
+; @test10: u64 x < sext 32 umin(a, const)
+; @test11: u64 x < smin(a, const), signedness mismatch, not transformed
+; @test12: s64 x < zext 32 smin(a, const), not transformed
+
+define i32 @test1(i64 %a, i64 %x) {
+entry:
+  %min = tail call i64 @llvm.umin.i64(i64 %a, i64 8)
+  %cmp = icmp ult i64 %x, %min
+  br i1 %cmp, label %ret0, label %ret1
+ret1: ret i32 1
+ret0: ret i32 0
+}
+
+; CHECK: @test1
+; CHECK: %0 = icmp ult i64 %x, %a
+; CHECK-NEXT: %1 = icmp ult i64 %x, 8
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %ret0, label %ret1
+
+define i32 @test2(i64 %a, i64 %x) {
+entry:
+  %max = tail call i64 @llvm.umax.i64(i64 %a, i64 8)
+  %cmp = icmp ult i64 %x, %max
+  br i1 %cmp, label %ret0, label %ret1
+ret1: ret i32 1
+ret0: ret i32 0
+}
+
+; CHECK: @test2
+; CHECK: %0 = icmp ult i64 %x, %a
+; CHECK-NEXT: %1 = icmp ult i64 %x, 8
+; CHECK-NEXT: %2 = select i1 %0, i1 true, i1 %1
+; CHECK-NEXT: br i1 %2, label %ret0, label %ret1
+
+define i32 @test3(i64 %a, i64 %x) {
+entry:
+  %min = tail call i64 @llvm.umin.i64(i64 %a, i64 8)
+  %cmp = icmp uge i64 %x, %min
+  br i1 %cmp, label %ret0, label %ret1
+ret1: ret i32 1
+ret0: ret i32 0
+}
+
+; CHECK: @test3
+; CHECK: %0 = icmp uge i64 %x, %a
+; CHECK-NEXT: %1 = icmp uge i64 %x, 8
+; CHECK-NEXT: %2 = select i1 %0, i1 true, i1 %1
+; CHECK-NEXT: br i1 %2, label %ret0, label %ret1
+
+define i32 @test4(i64 %a, i64 %x) {
+entry:
+  %max = tail call i64 @llvm.umax.i64(i64 %a, i64 8)
+  %cmp = icmp uge i64 %x, %max
+  br i1 %cmp, label %ret0, label %ret1
+ret1: ret i32 1
+ret0: ret i32 0
+}
+
+; CHECK: @test4
+; CHECK: %0 = icmp uge i64 %x, %a
+; CHECK-NEXT: %1 = icmp uge i64 %x, 8
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %ret0, label %ret1
+
+define i32 @test5(i64 %a, i64 %x) {
+entry:
+  %min = tail call i64 @llvm.umin.i64(i64 %a, i64 8)
+  %cmp = icmp uge i64 %min, %x
+  br i1 %cmp, label %ret0, label %ret1
+ret1: ret i32 1
+ret0: ret i32 0
+}
+
+; CHECK: @test5
+; CHECK: %0 = icmp ule i64 %x, %a
+; CHECK-NEXT: %1 = icmp ule i64 %x, 8
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %ret0, label %ret1
+
+define i32 @test6(i64 %a, i64 %b, i64 %x) {
+entry:
+  %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
+  %cmp = icmp ult i64 %x, %min
+  br i1 %cmp, label %ret0, label %ret1
+ret1: ret i32 1
+ret0: ret i32 0
+}
+
+; CHECK: @test6
+; CHECK: %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
+; CHECK-NEXT: %cmp = icmp ult i64 %x, %min
+; CHECK-NEXT: br i1 %cmp, label %ret0, label %ret1
+
+define i32 @test7(i64 %a, i64 %x) {
+entry:
+  %min = tail call i64 @llvm.smin.i64(i64 %a, i64 8)
+  %cmp = icmp slt i64 %x, %min
+  br i1 %cmp, label %ret0, label %ret1
+ret1: ret i32 1
+ret0: ret i32 0
+}
+
+; CHECK: @test7
+; CHECK: %0 = icmp slt i64 %x, %a
+; CHECK-NEXT: %1 = icmp slt i64 %x, 8
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %ret0, label %ret1
+
+define i32 @test8(i32 %a, i32 %x) {
+entry:
+  %min = tail call i32 @llvm.umin.i32(i32 %a, i32 8)
+  %cmp = icmp ult i32 %x, %min
+  br i1 %cmp, label %ret0, label %ret1
+ret1: ret i32 1
+ret0: ret i32 0
+}
+
+; CHECK: @test8
+; CHECK: %0 = icmp ult i32 %x, %a
+; CHECK-NEXT: %1 = icmp ult i32 %x, 8
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %ret0, label %ret1
+
+define i32 @test9(i32 %a, i64 %x) {
+entry:
+  %min = tail call i32 @llvm.umin.i32(i32 %a, i32 8)
+  %ext = zext i32 %min to i64
+  %cmp = icmp ult i64 %x, %ext
+  br i1 %cmp, label %ret0, label %ret1
+ret1: ret i32 1
+ret0: ret i32 0
+}
+
+; CHECK: @test9
+; CHECK-NOT: @llvm.umin
+; CHECK: %0 = zext i32 %a to i64
+; CHECK-NEXT: %1 = icmp ult i64 %x, %0
+; CHECK-NEXT: %2 = icmp ult i64 %x, 8
+; CHECK-NEXT: %3 = select i1 %1, i1 %2, i1 false
+; CHECK-NEXT: br i1 %3, label %ret0, label %ret1
+
+define i32 @test10(i32 %a, i64 %x) {
+entry:
+  %min = tail call i32 @llvm.umin.i32(i32 %a, i32 8)
+  %ext = sext i32 %min to i64
+  %cmp = icmp ult i64 %x, %ext
+  br i1 %cmp, label %ret0, label %ret1
+ret1: ret i32 1
+ret0: ret i32 0
+}
+
+; CHECK: @test10
+; CHECK-NOT: @llvm.umin
+; CHECK: %0 = sext i32 %a to i64
+; CHECK-NEXT: %1 = icmp ult i64 %x, %0
+; CHECK-NEXT: %2 = icmp ult i64 %x, 8
+; CHECK-NEXT: %3 = select i1 %1, i1 %2, i1 false
+; CHECK-NEXT: br i1 %3, label %ret0, label %ret1
+
+define i32 @test11(i64 %a, i64 %x) {
+entry:
+  %min = tail call i64 @llvm.smin.i64(i64 %a, i64 8)
+  %cmp = icmp ult i64 %x, %min
+  br i1 %cmp, label %ret0, label %ret1
+ret1: ret i32 1
+ret0: ret i32 0
+}
+
+; CHECK: @test11
+; CHECK: %min = tail call i64 @llvm.smin.i64(i64 %a, i64 8)
+; CHECK-NEXT: %cmp = icmp ult i64 %x, %min
+; CHECK-NEXT: br i1 %cmp, label %ret0, label %ret1
+
+define i32 @test12(i32 %a, i64 %x) {
+entry:
+  %min = tail call i32 @llvm.smin.i32(i32 %a, i32 8)
+  %ext = zext i32 %min to i64
+  %cmp = icmp slt i64 %x, %ext
+  br i1 %cmp, label %ret0, label %ret1
+ret1: ret i32 1
+ret0: ret i32 0
+}
+
+; CHECK: @test12
+; CHECK: %min = tail call i32 @llvm.smin.i32(i32 %a, i32 8)
+; CHECK-NEXT: %ext = zext i32 %min to i64
+; CHECK-NEXT: %cmp = icmp slt i64 %x, %ext
+; CHECK-NEXT: br i1 %cmp, label %ret0, label %ret1
+
+declare i64 @llvm.umin.i64(i64, i64)
+declare i64 @llvm.smin.i64(i64, i64)
+declare i64 @llvm.umax.i64(i64, i64)
+declare i64 @llvm.smax.i64(i64, i64)
+
+declare i32 @llvm.umin.i32(i32, i32)
+declare i32 @llvm.smin.i32(i32, i32)
+declare i32 @llvm.umax.i32(i32, i32)
+declare i32 @llvm.smax.i32(i32, i32)