Index: lib/Transforms/InstCombine/InstCombine.h
===================================================================
--- lib/Transforms/InstCombine/InstCombine.h
+++ lib/Transforms/InstCombine/InstCombine.h
@@ -162,6 +162,7 @@
   Instruction *visitUDiv(BinaryOperator &I);
   Instruction *visitSDiv(BinaryOperator &I);
   Instruction *visitFDiv(BinaryOperator &I);
+  Value *simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, bool Inverted);
   Value *FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS);
   Value *FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS);
   Instruction *visitAnd(BinaryOperator &I);
Index: lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -785,6 +785,65 @@
   return nullptr;
 }
 
+/// Try to fold a signed range check with lower bound 0 to an unsigned icmp.
+/// Example: (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n
+/// If \p Inverted is true then the check is for the inverted range, e.g.
+/// (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n
+/// \returns the new unsigned compare, or nullptr if the pattern does not match
+/// or the upper bound is not known to be non-negative.
+Value *InstCombiner::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1,
+                                        bool Inverted) {
+  // Check the lower range comparison, e.g. x >= 0
+  // InstCombine already ensured that if there is a constant it's on the RHS.
+  ConstantInt *RangeStart = dyn_cast<ConstantInt>(Cmp0->getOperand(1));
+  if (!RangeStart)
+    return nullptr;
+
+  ICmpInst::Predicate Pred0 = (Inverted ? Cmp0->getInversePredicate() :
+                               Cmp0->getPredicate());
+
+  // Accept x >= 0 or x > -1 (after potentially inverting the predicate).
+  if (!(Pred0 == ICmpInst::ICMP_SGT && RangeStart->isMinusOne()) &&
+      !(Pred0 == ICmpInst::ICMP_SGE && RangeStart->isZero()))
+    return nullptr;
+
+  ICmpInst::Predicate Pred1 = (Inverted ? Cmp1->getInversePredicate() :
+                               Cmp1->getPredicate());
+
+  Value *Input = Cmp0->getOperand(0);
+  Value *RangeEnd;
+  if (Cmp1->getOperand(0) == Input) {
+    // For the upper range compare we have: icmp x, n
+    RangeEnd = Cmp1->getOperand(1);
+  } else if (Cmp1->getOperand(1) == Input) {
+    // For the upper range compare we have: icmp n, x
+    RangeEnd = Cmp1->getOperand(0);
+    Pred1 = ICmpInst::getSwappedPredicate(Pred1);
+  } else {
+    return nullptr;
+  }
+
+  // Check the upper range comparison, e.g. x < n
+  ICmpInst::Predicate NewPred;
+  switch (Pred1) {
+    case ICmpInst::ICMP_SLT: NewPred = ICmpInst::ICMP_ULT; break;
+    case ICmpInst::ICMP_SLE: NewPred = ICmpInst::ICMP_ULE; break;
+    default: return nullptr;
+  }
+
+  // This simplification is only valid if the upper range is not negative.
+  bool IsNegative, IsNotNegative;
+  ComputeSignBit(RangeEnd, IsNotNegative, IsNegative, DL, 0, AT,
+                 Cmp1, DT);
+  if (!IsNotNegative)
+    return nullptr;
+
+  if (Inverted)
+    NewPred = ICmpInst::getInversePredicate(NewPred);
+
+  return Builder->CreateICmp(NewPred, Input, RangeEnd);
+}
+
 /// FoldAndOfICmps - Fold (icmp)&(icmp) if possible.
 Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
   ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
@@ -807,6 +866,14 @@
   if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, true, Builder))
     return V;
 
+  // E.g. (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n
+  if (Value *V = simplifyRangeCheck(LHS, RHS, false))
+    return V;
+
+  // E.g. (icmp slt x, n) & (icmp sge x, 0) --> icmp ult x, n
+  if (Value *V = simplifyRangeCheck(RHS, LHS, false))
+    return V;
+
   // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2).
   Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0);
   ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1));
@@ -1724,6 +1791,14 @@
           Builder->CreateAdd(B, ConstantInt::getSigned(B->getType(), -1)), A);
   }
 
+  // E.g. (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n
+  if (Value *V = simplifyRangeCheck(LHS, RHS, true))
+    return V;
+
+  // E.g. (icmp sgt x, n) | (icmp slt x, 0) --> icmp ugt x, n
+  if (Value *V = simplifyRangeCheck(RHS, LHS, true))
+    return V;
+
   // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2).
   if (!LHSCst || !RHSCst)
     return nullptr;
Index: test/Transforms/InstCombine/range-check.ll
===================================================================
--- test/Transforms/InstCombine/range-check.ll
+++ test/Transforms/InstCombine/range-check.ll
@@ -0,0 +1,159 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+; Check simplification of
+; (icmp sgt x, -1) & (icmp sgt/sge n, x) --> icmp ugt/uge n, x
+
+; CHECK-LABEL: define i1 @test_and1
+; CHECK: [[R:%[0-9]+]] = icmp ugt i32 %nn, %x
+; CHECK: ret i1 [[R]]
+define i1 @test_and1(i32 %x, i32 %n) {
+  %nn = and i32 %n, 2147483647
+  %a = icmp sge i32 %x, 0
+  %b = icmp slt i32 %x, %nn
+  %c = and i1 %a, %b
+  ret i1 %c
+}
+
+; CHECK-LABEL: define i1 @test_and2
+; CHECK: [[R:%[0-9]+]] = icmp uge i32 %nn, %x
+; CHECK: ret i1 [[R]]
+define i1 @test_and2(i32 %x, i32 %n) {
+  %nn = and i32 %n, 2147483647
+  %a = icmp sgt i32 %x, -1
+  %b = icmp sle i32 %x, %nn
+  %c = and i1 %a, %b
+  ret i1 %c
+}
+
+; CHECK-LABEL: define i1 @test_and3
+; CHECK: [[R:%[0-9]+]] = icmp ugt i32 %nn, %x
+; CHECK: ret i1 [[R]]
+define i1 @test_and3(i32 %x, i32 %n) {
+  %nn = and i32 %n, 2147483647
+  %a = icmp sgt i32 %nn, %x
+  %b = icmp sge i32 %x, 0
+  %c = and i1 %a, %b
+  ret i1 %c
+}
+
+; CHECK-LABEL: define i1 @test_and4
+; CHECK: [[R:%[0-9]+]] = icmp uge i32 %nn, %x
+; CHECK: ret i1 [[R]]
+define i1 @test_and4(i32 %x, i32 %n) {
+  %nn = and i32 %n, 2147483647
+  %a = icmp sge i32 %nn, %x
+  %b = icmp sge i32 %x, 0
+  %c = and i1 %a, %b
+  ret i1 %c
+}
+
+; CHECK-LABEL: define i1 @test_or1
+; CHECK: [[R:%[0-9]+]] = icmp ule i32 %nn, %x
+; CHECK: ret i1 [[R]]
+define i1 @test_or1(i32 %x, i32 %n) {
+  %nn = and i32 %n, 2147483647
+  %a = icmp slt i32 %x, 0
+  %b = icmp sge i32 %x, %nn
+  %c = or i1 %a, %b
+  ret i1 %c
+}
+
+; CHECK-LABEL: define i1 @test_or2
+; CHECK: [[R:%[0-9]+]] = icmp ult i32 %nn, %x
+; CHECK: ret i1 [[R]]
+define i1 @test_or2(i32 %x, i32 %n) {
+  %nn = and i32 %n, 2147483647
+  %a = icmp sle i32 %x, -1
+  %b = icmp sgt i32 %x, %nn
+  %c = or i1 %a, %b
+  ret i1 %c
+}
+
+; CHECK-LABEL: define i1 @test_or3
+; CHECK: [[R:%[0-9]+]] = icmp ule i32 %nn, %x
+; CHECK: ret i1 [[R]]
+define i1 @test_or3(i32 %x, i32 %n) {
+  %nn = and i32 %n, 2147483647
+  %a = icmp sle i32 %nn, %x
+  %b = icmp slt i32 %x, 0
+  %c = or i1 %a, %b
+  ret i1 %c
+}
+
+; CHECK-LABEL: define i1 @test_or4
+; CHECK: [[R:%[0-9]+]] = icmp ult i32 %nn, %x
+; CHECK: ret i1 [[R]]
+define i1 @test_or4(i32 %x, i32 %n) {
+  %nn = and i32 %n, 2147483647
+  %a = icmp slt i32 %nn, %x
+  %b = icmp slt i32 %x, 0
+  %c = or i1 %a, %b
+  ret i1 %c
+}
+
+; Negative tests
+
+; CHECK-LABEL: define i1 @negative1
+; CHECK: %a = icmp
+; CHECK: %b = icmp
+; CHECK: %c = and i1 %a, %b
+; CHECK: ret i1 %c
+define i1 @negative1(i32 %x, i32 %n) {
+  %nn = and i32 %n, 2147483647
+  %a = icmp slt i32 %x, %nn
+  %b = icmp sgt i32 %x, 0      ; should be: icmp sge
+  %c = and i1 %a, %b
+  ret i1 %c
+}
+
+; CHECK-LABEL: define i1 @negative2
+; CHECK: %a = icmp
+; CHECK: %b = icmp
+; CHECK: %c = and i1 %a, %b
+; CHECK: ret i1 %c
+define i1 @negative2(i32 %x, i32 %n) {
+  %a = icmp slt i32 %x, %n     ; n can be negative
+  %b = icmp sge i32 %x, 0
+  %c = and i1 %a, %b
+  ret i1 %c
+}
+
+; CHECK-LABEL: define i1 @negative3
+; CHECK: %a = icmp
+; CHECK: %b = icmp
+; CHECK: %c = and i1 %a, %b
+; CHECK: ret i1 %c
+define i1 @negative3(i32 %x, i32 %y, i32 %n) {
+  %nn = and i32 %n, 2147483647
+  %a = icmp slt i32 %x, %nn
+  %b = icmp sge i32 %y, 0      ; should compare %x and not %y
+  %c = and i1 %a, %b
+  ret i1 %c
+}
+
+; CHECK-LABEL: define i1 @negative4
+; CHECK: %a = icmp
+; CHECK: %b = icmp
+; CHECK: %c = and i1 %a, %b
+; CHECK: ret i1 %c
+define i1 @negative4(i32 %x, i32 %n) {
+  %nn = and i32 %n, 2147483647
+  %a = icmp ne i32 %x, %nn     ; should be: icmp slt/sle
+  %b = icmp sge i32 %x, 0
+  %c = and i1 %a, %b
+  ret i1 %c
+}
+
+; CHECK-LABEL: define i1 @negative5
+; CHECK: %a = icmp
+; CHECK: %b = icmp
+; CHECK: %c = or i1 %a, %b
+; CHECK: ret i1 %c
+define i1 @negative5(i32 %x, i32 %n) {
+  %nn = and i32 %n, 2147483647
+  %a = icmp slt i32 %x, %nn
+  %b = icmp sge i32 %x, 0
+  %c = or i1 %a, %b            ; should be: and
+  ret i1 %c
+}
+