Index: llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -862,36 +862,68 @@
 Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS,
                                                    bool JoinedByAnd,
                                                    Instruction &CxtI) {
-  ICmpInst::Predicate Pred = LHS->getPredicate();
-  if (Pred != RHS->getPredicate())
-    return nullptr;
-  if (JoinedByAnd && Pred != ICmpInst::ICMP_NE)
-    return nullptr;
-  if (!JoinedByAnd && Pred != ICmpInst::ICMP_EQ)
-    return nullptr;
-
-  // TODO support vector splats
-  ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
-  ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
-  if (!LHSC || !RHSC || !LHSC->isZero() || !RHSC->isZero())
-    return nullptr;
-
-  Value *A, *B, *C, *D;
-  if (match(LHS->getOperand(0), m_And(m_Value(A), m_Value(B))) &&
-      match(RHS->getOperand(0), m_And(m_Value(C), m_Value(D)))) {
-    if (A == D || B == D)
-      std::swap(C, D);
-    if (B == C)
-      std::swap(A, B);
-
-    if (A == C &&
-        isKnownToBeAPowerOfTwo(B, false, 0, &CxtI) &&
-        isKnownToBeAPowerOfTwo(D, false, 0, &CxtI)) {
-      Value *Mask = Builder.CreateOr(B, D);
-      Value *Masked = Builder.CreateAnd(A, Mask);
-      auto NewPred = JoinedByAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
-      return Builder.CreateICmp(NewPred, Masked, Mask);
-    }
+  // Match one side of the and/or as a test of a single bit of A:
+  //   (A & Mask) ==/!= 0      binds A and Mask and leaves ShAmt null.
+  //   (A l<< ShAmt) s>=/s< 0  binds A and ShAmt and leaves Mask null. This is
+  //                           what (A & (signbit l>> ShAmt)) ==/!= 0 gets
+  //                           folded into in the middle-end.
+  // No IR is created here. The fold can still be rejected after a side has
+  // matched, and an instruction built for a rejected fold is erased as dead
+  // and reported as a change on every iteration, so InstCombine would never
+  // reach a fixed point.
+  auto matchBitTest = [JoinedByAnd](ICmpInst *Cmp, Value *&A, Value *&Mask,
+                                    Value *&ShAmt) -> bool {
+    Mask = nullptr;
+    ShAmt = nullptr;
+
+    // TODO support vector splats
+    auto *CmpC = dyn_cast<ConstantInt>(Cmp->getOperand(1));
+    if (!CmpC)
+      return false;
+
+    ICmpInst::Predicate Pred = Cmp->getPredicate();
+    if (Pred == (JoinedByAnd ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ))
+      return CmpC->isZero() &&
+             match(Cmp->getOperand(0), m_And(m_Value(A), m_Value(Mask)));
+
+    // s< 0 when joined by 'and'; s> -1 (i.e. s>= 0) when joined by 'or'.
+    if (Pred == (JoinedByAnd ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGT))
+      return (JoinedByAnd ? CmpC->isZero() : CmpC->isMinusOne()) &&
+             match(Cmp->getOperand(0), m_Shl(m_Value(A), m_Value(ShAmt)));
+
+    return false;
+  };
+
+  Value *A, *B, *C, *D, *LHSShAmt, *RHSShAmt;
+  if (!matchBitTest(LHS, A, B, LHSShAmt) || !matchBitTest(RHS, C, D, RHSShAmt))
+    return nullptr;
+
+  // Line the common operand up as A == C. Only the operands of an 'and' can
+  // be exchanged; a side matched as 'shl' has a null mask and a fixed A/C.
+  if (D && (A == D || B == D))
+    std::swap(C, D);
+  if (B == C)
+    std::swap(A, B);
+
+  // (signbit l>> ShAmt) is a power of two by construction, so only masks that
+  // already exist in the IR have to be checked.
+  if (A == C && (!B || isKnownToBeAPowerOfTwo(B, false, 0, &CxtI)) &&
+      (!D || isKnownToBeAPowerOfTwo(D, false, 0, &CxtI))) {
+    // The fold is certain from here on, so it is safe to create instructions.
+    auto getOrCreateMask = [&](Value *Mask, Value *ShAmt) -> Value * {
+      if (Mask)
+        return Mask;
+      Type *Ty = ShAmt->getType();
+      return Builder.CreateLShr(
+          ConstantInt::get(Ty, APInt::getSignMask(Ty->getScalarSizeInBits())),
+          ShAmt);
+    };
+    B = getOrCreateMask(B, LHSShAmt);
+    D = getOrCreateMask(D, RHSShAmt);
+    Value *Mask = Builder.CreateOr(B, D);
+    Value *Masked = Builder.CreateAnd(A, Mask);
+    auto NewPred = JoinedByAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
+    return Builder.CreateICmp(NewPred, Masked, Mask);
   }
 
   return nullptr;
Index: llvm/test/Transforms/InstCombine/onehot_merge.ll
===================================================================
--- llvm/test/Transforms/InstCombine/onehot_merge.ll
+++ llvm/test/Transforms/InstCombine/onehot_merge.ll
@@ -153,12 +153,11 @@
 define i1 @foo1_and_signbit_lshr_without_shifting_signbit(i32 %k, i32 %c1, i32 %c2) {
 ; CHECK-LABEL: @foo1_and_signbit_lshr_without_shifting_signbit(
 ; CHECK-NEXT:    [[T0:%.*]] = shl i32 1, [[C1:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = and i32 [[T0]], [[K:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = icmp eq i32 [[T1]], 0
-; CHECK-NEXT:    [[T3:%.*]] = shl i32 [[K]], [[C2:%.*]]
-; CHECK-NEXT:    [[T4:%.*]] = icmp sgt i32 [[T3]], -1
-; CHECK-NEXT:    [[OR:%.*]] = or i1 [[T2]], [[T4]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 -2147483648, [[C2:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = or i32 [[T0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], [[K:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne i32 [[TMP3]], [[TMP2]]
+; CHECK-NEXT:    ret i1 [[TMP4]]
 ;
   %t0 = shl i32 1, %c1
   %t1 = and i32 %t0, %k
@@ -172,12 +171,11 @@
 define i1 @foo1_or_signbit_lshr_without_shifting_signbit(i32 %k, i32 %c1, i32 %c2) {
 ; CHECK-LABEL: @foo1_or_signbit_lshr_without_shifting_signbit(
 ; CHECK-NEXT:    [[T0:%.*]] = shl i32 1, [[C1:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = and i32 [[T0]], [[K:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = icmp ne i32 [[T1]], 0
-; CHECK-NEXT:    [[T3:%.*]] = shl i32 [[K]], [[C2:%.*]]
-; CHECK-NEXT:    [[T4:%.*]] = icmp slt i32 [[T3]], 0
-; CHECK-NEXT:    [[OR:%.*]] = and i1 [[T2]], [[T4]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 -2147483648, [[C2:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = or i32 [[T0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], [[K:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i32 [[TMP3]], [[TMP2]]
+; CHECK-NEXT:    ret i1 [[TMP4]]
 ;
   %t0 = shl i32 1, %c1
   %t1 = and i32 %t0, %k