Index: lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -741,6 +741,34 @@
   return nullptr;
 }
 
+// For an ICMP whose RHS is zero, check whether the ICMP is equivalent to
+// testing a group of bits of an integer value against zero. If so, return
+// (true, CheckSet, Mask): CheckSet is true when the ICMP holds if any masked
+// bit is set (NE, UGT) and false when it holds if all masked bits are clear
+// (EQ); Mask identifies the bit group. This mostly matches an 'and' of an
+// integer with a constant mask, but other patterns could be handled as well,
+// e.g. CC == SLT against zero effectively tests whether the sign bit is set.
+static std::tuple<bool, bool, const APInt *>
+IsAnyBitSet(Value *LHS, ICmpInst::Predicate CC) {
+  auto *Inst = dyn_cast<Instruction>(LHS);
+
+  if (Inst && Inst->getOpcode() == Instruction::And) {
+    auto *Mask = dyn_cast<ConstantInt>(Inst->getOperand(1));
+    if (!Mask)
+      return std::make_tuple(false, false, nullptr);
+
+    switch (CC) {
+    default: break;
+    case ICmpInst::ICMP_EQ:
+      return std::make_tuple(true, false, &Mask->getValue());
+    case ICmpInst::ICMP_NE:
+    case ICmpInst::ICMP_UGT:
+      return std::make_tuple(true, true, &Mask->getValue());
+    }
+  }
+
+  return std::make_tuple(false, false, nullptr);
+}
+
 /// Try to fold a signed range checked with lower bound 0 to an unsigned icmp.
 /// Example: (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n
 /// If \p Inverted is true then the check is for the inverted range, e.g.
@@ -797,6 +825,30 @@
   return Builder->CreateICmp(NewPred, Input, RangeEnd);
 }
 
+Value *InstCombiner::FoldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
+  Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0);
+  ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1));
+  ConstantInt *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1));
+  if (!LHSCst || !RHSCst) return nullptr;
+  ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
+
+  // (icmp eq/ne (and a, P), 0) ^ (icmp eq/ne (and b, P), 0)
+  //   --> icmp ne (and a, P), (and b, P)
+  // when P is a power of 2 and both compares test in the same direction:
+  // each masked value is either 0 or P, so the xor is true exactly when the
+  // two masked values differ.
+  if (RHSCst->isZero() && LHSCst->isZero()) {
+    bool IsBitCheck1, IsBitCheck2, CheckSet1, CheckSet2;
+    const APInt *Mask1, *Mask2;
+    std::tie(IsBitCheck1, CheckSet1, Mask1) = IsAnyBitSet(Val, LHSCC);
+    std::tie(IsBitCheck2, CheckSet2, Mask2) = IsAnyBitSet(Val2, RHSCC);
+    if (IsBitCheck1 && IsBitCheck2 && CheckSet1 == CheckSet2 &&
+        Mask1->getBitWidth() == Mask2->getBitWidth() && *Mask1 == *Mask2 &&
+        Mask1->isPowerOf2())
+      return Builder->CreateICmp(ICmpInst::ICMP_NE, Val2, Val);
+  }
+  return nullptr;
+}
+
 /// Fold (icmp)&(icmp) if possible.
 Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
   ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
@@ -879,6 +931,27 @@
     }
   }
 
+  // E.g. (icmp eq x, 0) & (icmp ne y, 0) --> icmp ult x, y, where
+  // x = and a, mask1 and y = and b, mask2. This is valid when the lowest set
+  // bit of mask1 is, as a value, at least as large as all of mask2: any
+  // nonzero x is then >= every possible y, so x <u y holds exactly when
+  // x == 0 and y != 0. (And symmetrically with the two sides swapped.)
+  if (RHSCst->isZero() && LHSCst->isZero()) {
+    bool IsBitCheck1, IsBitCheck2, CheckSet1, CheckSet2;
+    const APInt *Mask1, *Mask2;
+    std::tie(IsBitCheck1, CheckSet1, Mask1) = IsAnyBitSet(Val, LHSCC);
+    std::tie(IsBitCheck2, CheckSet2, Mask2) = IsAnyBitSet(Val2, RHSCC);
+
+    if (IsBitCheck1 && IsBitCheck2 && CheckSet1 != CheckSet2 &&
+        Mask1->getBitWidth() == Mask2->getBitWidth()) {
+      // Isolate the lowest set bit of each mask as a value.
+      APInt LowBit1 = *Mask1 & -*Mask1;
+      APInt LowBit2 = *Mask2 & -*Mask2;
+      if (!CheckSet1 && LowBit1.uge(*Mask2))
+        return Builder->CreateICmp(ICmpInst::ICMP_ULT, Val, Val2);
+      if (!CheckSet2 && LowBit2.uge(*Mask1))
+        return Builder->CreateICmp(ICmpInst::ICMP_ULT, Val2, Val);
+    }
+  }
+
   // From here on, we only handle:
   // (icmp1 A, C1) & (icmp2 A, C2) --> something simpler.
   if (Val != Val2) return nullptr;
@@ -2714,9 +2787,16 @@
       match(Op1, m_Not(m_Specific(A))))
     return BinaryOperator::CreateNot(Builder->CreateAnd(A, B));
 
-  // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B)
   if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1)))
-    if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0)))
+    if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) {
+
+      // E.g. for xor (icmp eq %A, 0), (icmp eq %B, 0), if both %A and %B are
+      // known to be either 0 or the same power of 2 (say 8), we can simplify
+      // to (icmp ne %A, %B).
+      if (Value *Res = FoldXorOfICmps(LHS, RHS))
+        return replaceInstUsesWith(I, Res);
+
+      // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B)
       if (PredicatesFoldable(LHS->getPredicate(), RHS->getPredicate())) {
         if (LHS->getOperand(0) == RHS->getOperand(1) &&
             LHS->getOperand(1) == RHS->getOperand(0))
@@ -2731,6 +2811,7 @@
                                                    Builder));
       }
     }
+    }
 
   if (Instruction *CastedXor = foldCastedBitwiseLogic(I))
     return CastedXor;
Index: lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- lib/Transforms/InstCombine/InstCombineInternal.h
+++ lib/Transforms/InstCombine/InstCombineInternal.h
@@ -225,6 +225,7 @@
   Instruction *visitFDiv(BinaryOperator &I);
   Value *simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, bool Inverted);
   Value *FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS);
+  Value *FoldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS);
   Value *FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS);
   Instruction *visitAnd(BinaryOperator &I);
   Value *FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI);
Index: test/Transforms/InstCombine/and-or-icmps.ll
===================================================================
--- test/Transforms/InstCombine/and-or-icmps.ll
+++ test/Transforms/InstCombine/and-or-icmps.ll
@@ -51,3 +51,164 @@
   ret i1 %tmp1042
 }
 
+define i1 @test2(i32 %a, i32 %b) {
+
+; CHECK-LABEL: @test2(
+; CHECK-NEXT:  %val1 = and i32 %a, 8
+; CHECK-NEXT:  %val2 = and i32 %b, 8
+; CHECK-NEXT:  [[TEST2TMP:%.*]] = icmp ult i32 %val2, %val1
+; CHECK-NEXT:  ret i1 [[TEST2TMP]]
+
+  %val1 = and i32 %a, 8
+  %val2 = and i32 %b, 8
+  %cmp.a = icmp ne i32 %val1, 0
+  %cmp.b = icmp eq i32 %val2, 0
+  %and = and i1 %cmp.b, %cmp.a
+  ret i1 %and
+}
+
+; Same as @test2 but with the 'and' operands commuted.
+define i1 @test3(i32 %a, i32 %b) {
+
+; CHECK-LABEL: @test3(
+; CHECK-NEXT:  %val1 = and i32 %a, 8
+; CHECK-NEXT:  %val2 = and i32 %b, 8
+; CHECK-NEXT:  [[TEST3TMP:%.*]] = icmp ult i32 %val2, %val1
+; CHECK-NEXT:  ret i1 [[TEST3TMP]]
+
+  %val1 = and i32 %a, 8
+  %val2 = and i32 %b, 8
+  %cmp.a = icmp ne i32 %val1, 0
+  %cmp.b = icmp eq i32 %val2, 0
+  %and = and i1 %cmp.a, %cmp.b
+  ret i1 %and
+}
+
+; Multi-bit masks: the eq-side mask's lowest bit (16) is >= the whole
+; ne-side mask (15), so the fold fires.
+define i1 @test4(i32 %a, i32 %b) {
+
+; CHECK-LABEL: @test4(
+; CHECK-NEXT:  %val1 = and i32 %a, 15
+; CHECK-NEXT:  %val2 = and i32 %b, 16
+; CHECK-NEXT:  [[TEST4TMP:%.*]] = icmp ult i32 %val2, %val1
+; CHECK-NEXT:  ret i1 [[TEST4TMP]]
+
+  %val1 = and i32 %a, 15
+  %val2 = and i32 %b, 16
+  %cmp.a = icmp ne i32 %val1, 0
+  %cmp.b = icmp eq i32 %val2, 0
+  %and = and i1 %cmp.a, %cmp.b
+  ret i1 %and
+}
+
+; Same as @test4 but with the 'and' operands commuted.
+define i1 @test5(i32 %a, i32 %b) {
+
+; CHECK-LABEL: @test5(
+; CHECK-NEXT:  %val1 = and i32 %a, 15
+; CHECK-NEXT:  %val2 = and i32 %b, 16
+; CHECK-NEXT:  [[TEST5TMP:%.*]] = icmp ult i32 %val2, %val1
+; CHECK-NEXT:  ret i1 [[TEST5TMP]]
+
+  %val1 = and i32 %a, 15
+  %val2 = and i32 %b, 16
+  %cmp.a = icmp ne i32 %val1, 0
+  %cmp.b = icmp eq i32 %val2, 0
+  %and = and i1 %cmp.b, %cmp.a
+  ret i1 %and
+}
+
+; Negative test: the eq-side mask's lowest bit (8) is below the ne-side
+; mask (16), so the fold must not fire.
+define i1 @test6(i32 %a, i32 %b) {
+
+; CHECK-LABEL: @test6(
+; CHECK-NOT: icmp ult i32
+; CHECK: ret i1
+
+  %val1 = and i32 %a, 16
+  %val2 = and i32 %b, 24
+  %cmp.a = icmp ne i32 %val1, 0
+  %cmp.b = icmp eq i32 %val2, 0
+  %and = and i1 %cmp.b, %cmp.a
+  ret i1 %and
+}
+
+; Negative test: the bit widths differ, so the fold must not fire.
+define i1 @test7(i16 %a, i32 %b) {
+
+; CHECK-LABEL: @test7(
+; CHECK-NOT: icmp ult
+; CHECK: ret i1
+
+  %val1 = and i16 %a, 15
+  %val2 = and i32 %b, 24
+  %cmp.a = icmp ne i16 %val1, 0
+  %cmp.b = icmp eq i32 %val2, 0
+  %and = and i1 %cmp.b, %cmp.a
+  ret i1 %and
+}
+
+define i1 @test8(i32 %a, i32 %b) {
+
+; CHECK-LABEL: @test8(
+; CHECK-NEXT:  [[TEST8TMP1:%.*]] = xor i32 %a, %b
+; CHECK-NEXT:  [[TEST8TMP2:%.*]] = and i32 [[TEST8TMP1]], 8
+; CHECK-NEXT:  icmp ne i32 [[TEST8TMP2]], 0
+; CHECK-NEXT:  ret i1
+
+  %val1 = and i32 %a, 8
+  %val2 = and i32 %b, 8
+  %cmp.a = icmp ne i32 %val1, 0
+  %cmp.b = icmp ne i32 %val2, 0
+  %and = xor i1 %cmp.b, %cmp.a
+  ret i1 %and
+}
+
+; Negative test: the mask is not a power of 2, so the xor fold must not fire.
+define i1 @test9(i32 %a, i32 %b) {
+
+; CHECK-LABEL: @test9(
+; CHECK-NEXT: and
+; CHECK-NEXT: and
+; CHECK-NEXT: icmp ne
+; CHECK-NEXT: icmp ne
+; CHECK-NEXT: xor
+; CHECK-NEXT: ret i1
+
+  %val1 = and i32 %a, 24
+  %val2 = and i32 %b, 24
+  %cmp.a = icmp ne i32 %val1, 0
+  %cmp.b = icmp ne i32 %val2, 0
+  %and = xor i1 %cmp.b, %cmp.a
+  ret i1 %and
+}
+
+; Same as @test8 but with eq predicates; the xor of the negations is the
+; same condition.
+define i1 @test10(i32 %a, i32 %b) {
+
+; CHECK-LABEL: @test10(
+; CHECK-NEXT:  [[TEST10TMP1:%.*]] = xor i32 %a, %b
+; CHECK-NEXT:  [[TEST10TMP2:%.*]] = and i32 [[TEST10TMP1]], 8
+; CHECK-NEXT:  icmp ne i32 [[TEST10TMP2]], 0
+; CHECK-NEXT:  ret i1
+
+  %val1 = and i32 %a, 8
+  %val2 = and i32 %b, 8
+  %cmp.a = icmp eq i32 %val1, 0
+  %cmp.b = icmp eq i32 %val2, 0
+  %and = xor i1 %cmp.b, %cmp.a
+  ret i1 %and
+}
+
+; Negative test: the bit widths differ, so the xor fold must not fire.
+define i1 @test11(i16 %a, i32 %b) {
+
+; CHECK-LABEL: @test11(
+; CHECK-NEXT: and
+; CHECK-NEXT: and
+; CHECK-NEXT: icmp ne
+; CHECK-NEXT: icmp ne
+; CHECK-NEXT: xor
+; CHECK-NEXT: ret i1
+
+  %val1 = and i16 %a, 8
+  %val2 = and i32 %b, 8
+  %cmp.a = icmp ne i16 %val1, 0
+  %cmp.b = icmp ne i32 %val2, 0
+  %and = xor i1 %cmp.b, %cmp.a
+  ret i1 %and
+}
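
Note: the two equivalences exercised above are easy to get wrong at the
boundaries, so here is a standalone brute-force check of both fold
conditions (plain C++, no LLVM dependency; the 4-bit width and all
identifiers are choices local to this sketch, not part of the patch). It
enumerates every pair of masked values and compares each fold's input
predicate against its output predicate:

#include <cstdio>

int main() {
  const unsigned W = 16; // exhaustive over all 4-bit values and masks
  unsigned Failures = 0;
  for (unsigned A = 0; A < W; ++A)
    for (unsigned B = 0; B < W; ++B)
      for (unsigned M1 = 0; M1 < W; ++M1)
        for (unsigned M2 = 0; M2 < W; ++M2) {
          unsigned X = A & M1, Y = B & M2;
          // Lowest set bit of M1 as a value (0 when M1 == 0).
          unsigned Low1 = M1 & (0u - M1);
          // And-fold: (icmp eq x, 0) & (icmp ne y, 0) --> icmp ult x, y,
          // guarded by lowest-set-bit(mask1) >= mask2.
          if (Low1 >= M2 && ((X == 0 && Y != 0) != (X < Y)))
            ++Failures;
          // Xor-fold: equal power-of-2 masks, same compare direction:
          // (icmp ne x, 0) ^ (icmp ne y, 0) --> icmp ne x, y.
          bool Pow2 = M1 != 0 && (M1 & (M1 - 1)) == 0;
          if (M1 == M2 && Pow2 && (((X != 0) != (Y != 0)) != (X != Y)))
            ++Failures;
        }
  std::printf("failures: %u\n", Failures); // expect 0
  return Failures != 0;
}

Widening W covers larger types at the cost of runtime; 4 bits already
exercises every mask shape of interest (empty, single-bit, contiguous, and
split masks).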