diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -739,6 +739,53 @@ return Builder.CreateICmp(NewPred, Input, RangeEnd); } +// Fold (iszero(A & K1) ^ iszero(B & K2)) +// => iszero(((lshr A, log2(K1)) ^ (lshr B, log2(K2))) & 1) +Value *InstCombinerImpl::foldXorOfICmpsOfAndWithPow2(ICmpInst *LHS, + ICmpInst *RHS, + Instruction *CtxI) { + if (LHS->getPredicate() != CmpInst::ICMP_NE || + RHS->getPredicate() != CmpInst::ICMP_NE) + return nullptr; + + if (!match(LHS->getOperand(1), m_Zero()) || + !match(RHS->getOperand(1), m_Zero())) + return nullptr; + + if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) + return nullptr; + + Value *L1, *L2, *R1, *R2; + if (match(LHS->getOperand(0), m_And(m_Value(L1), m_Value(L2))) && + match(RHS->getOperand(0), m_And(m_Value(R1), m_Value(R2)))) { + if (L1 == R1) + return nullptr; + const APInt *L2Int, *R2Int; + if (match(L2, m_Power2(L2Int)) && match(R2, m_Power2(R2Int))) { + if (L2Int->eq(*R2Int)) { + Value *BitOperation = Builder.CreateXor(L1, R1); + Value *AndMask = Builder.CreateAnd( + BitOperation, + ConstantInt::get(BitOperation->getType(), L2Int->getZExtValue())); + return Builder.CreateICmp( + CmpInst::ICMP_NE, AndMask, + ConstantInt::getNullValue(AndMask->getType())); + } else { + Value *L1Lshr = Builder.CreateLShr( + L1, ConstantInt::get(L1->getType(), L2Int->logBase2())); + Value *R1Lshr = Builder.CreateLShr( + R1, ConstantInt::get(R1->getType(), R2Int->logBase2())); + Value *BitOperation = Builder.CreateXor(L1Lshr, R1Lshr); + Value *And = Builder.CreateAnd( + BitOperation, ConstantInt::get(BitOperation->getType(), 1)); + return Builder.CreateICmp(CmpInst::ICMP_NE, And, + ConstantInt::getNullValue(And->getType())); + } + } + } + return nullptr; +} + // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) Value *InstCombinerImpl::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, @@ -3555,6 +3602,9 @@ Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); + if (Value *V = foldXorOfICmpsOfAndWithPow2(LHS, RHS, &I)) + return V; + if (predicatesFoldable(PredL, PredR)) { if (LHS0 == RHS1 && LHS1 == RHS0) { std::swap(LHS0, LHS1); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -356,6 +356,9 @@ Value *foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1, ICmpInst *ICmp2, bool IsAnd); + Value *foldXorOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, + Instruction *CtxI); + /// Optimize (fcmp)&(fcmp) or (fcmp)|(fcmp). /// NOTE: Unlike most of instcombine, this returns a Value which should /// already be inserted into the function. diff --git a/llvm/test/Transforms/InstCombine/bit-checks.ll b/llvm/test/Transforms/InstCombine/bit-checks.ll --- a/llvm/test/Transforms/InstCombine/bit-checks.ll +++ b/llvm/test/Transforms/InstCombine/bit-checks.ll @@ -1317,11 +1317,11 @@ define i1 @xor_of_icmps_of_pow2(i32 %a, i32 %b) { ; CHECK-LABEL: @xor_of_icmps_of_pow2( -; CHECK-NEXT: [[AND_A:%.*]] = and i32 [[A:%.*]], 2 -; CHECK-NEXT: [[AND_A_ZERO:%.*]] = icmp ne i32 [[AND_A]], 0 -; CHECK-NEXT: [[AND_B:%.*]] = and i32 [[B:%.*]], 8 -; CHECK-NEXT: [[AND_B_ZERO:%.*]] = icmp ne i32 [[AND_B]], 0 -; CHECK-NEXT: [[RET:%.*]] = xor i1 [[AND_A_ZERO]], [[AND_B_ZERO]] +; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[A:%.*]], 1 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[B:%.*]], 3 +; CHECK-NEXT: [[TMP3:%.*]] = xor i32 [[TMP1]], [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = and i32 [[TMP3]], 1 +; CHECK-NEXT: [[RET:%.*]] = icmp ne i32 [[TMP4]], 0 ; CHECK-NEXT: ret i1 [[RET]] ; %and_a = and i32 %a, 2 @@ -1334,11 +1334,10 @@ define i1 @xor_of_icmps_of_pow2_1(i32 %a, i32 %b) { ; CHECK-LABEL: @xor_of_icmps_of_pow2_1( -; CHECK-NEXT: [[AND_A:%.*]] = and i32 [[A:%.*]], 1 -; CHECK-NEXT: [[AND_A_ZERO:%.*]] = icmp ne i32 [[AND_A]], 0 -; CHECK-NEXT: [[AND_B:%.*]] = and i32 [[B:%.*]], 8 -; CHECK-NEXT: [[AND_B_ZERO:%.*]] = icmp ne i32 [[AND_B]], 0 -; CHECK-NEXT: [[RET:%.*]] = xor i1 [[AND_A_ZERO]], [[AND_B_ZERO]] +; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[B:%.*]], 3 +; CHECK-NEXT: [[TMP2:%.*]] = xor i32 [[TMP1]], [[A:%.*]] +; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 1 +; CHECK-NEXT: [[RET:%.*]] = icmp ne i32 [[TMP3]], 0 ; CHECK-NEXT: ret i1 [[RET]] ; %and_a = and i32 %a, 1 @@ -1368,11 +1367,9 @@ define i1 @xor_of_icmps_of_pow2_same_constant(i32 %a, i32 %b) { ; CHECK-LABEL: @xor_of_icmps_of_pow2_same_constant( -; CHECK-NEXT: [[AND_A:%.*]] = and i32 [[A:%.*]], 8 -; CHECK-NEXT: [[AND_A_ZERO:%.*]] = icmp ne i32 [[AND_A]], 0 -; CHECK-NEXT: [[AND_B:%.*]] = and i32 [[B:%.*]], 8 -; CHECK-NEXT: [[AND_B_ZERO:%.*]] = icmp ne i32 [[AND_B]], 0 -; CHECK-NEXT: [[RET:%.*]] = xor i1 [[AND_A_ZERO]], [[AND_B_ZERO]] +; CHECK-NEXT: [[TMP1:%.*]] = xor i32 [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 8 +; CHECK-NEXT: [[RET:%.*]] = icmp ne i32 [[TMP2]], 0 ; CHECK-NEXT: ret i1 [[RET]] ; %and_a = and i32 %a, 8