Index: llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -862,36 +862,87 @@
 Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS,
                                                    bool JoinedByAnd,
                                                    Instruction &CxtI) {
-  ICmpInst::Predicate Pred = LHS->getPredicate();
-  if (Pred != RHS->getPredicate())
-    return nullptr;
-  if (JoinedByAnd && Pred != ICmpInst::ICMP_NE)
-    return nullptr;
-  if (!JoinedByAnd && Pred != ICmpInst::ICMP_EQ)
+  // Exit early if neither LHS nor RHS is an equality comparison. The
+  // generalized isZero() check decomposes (A l<< K) s>=/s< 0 into
+  // (A & (signbit l>> K)) ==/!= 0, which introduces an extra signbit-shift
+  // instruction, so do not try to decompose both sides.
+  if (!LHS->isEquality() && !RHS->isEquality())
     return nullptr;
 
-  // TODO support vector splats
-  ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
-  ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
-  if (!LHSC || !RHSC || !LHSC->isZero() || !RHSC->isZero())
-    return nullptr;
+  auto isCompareToZeroOrEquivalentForm = [](ICmpInst *Compare, bool JoinedByAnd,
+                                            InstCombiner::BuilderTy &Builder,
+                                            Value *&A, Value *&B) -> bool {
+    if (!Compare->hasOneUse())
+      return false;
 
-  Value *A, *B, *C, *D;
-  if (match(LHS->getOperand(0), m_And(m_Value(A), m_Value(B))) &&
-      match(RHS->getOperand(0), m_And(m_Value(C), m_Value(D)))) {
-    if (A == D || B == D)
-      std::swap(C, D);
-    if (B == C)
-      std::swap(A, B);
-
-    if (A == C &&
-        isKnownToBeAPowerOfTwo(B, false, 0, &CxtI) &&
-        isKnownToBeAPowerOfTwo(D, false, 0, &CxtI)) {
-      Value *Mask = Builder.CreateOr(B, D);
-      Value *Masked = Builder.CreateAnd(A, Mask);
-      auto NewPred = JoinedByAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
-      return Builder.CreateICmp(NewPred, Masked, Mask);
+    if (Compare->isEquality()) {
+      if (Compare->getPredicate() !=
+          (JoinedByAnd ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ))
+        return false;
+
+      // TODO support vector splats
+      ConstantInt *C = dyn_cast<ConstantInt>(Compare->getOperand(1));
+      if (!C || !C->isZero())
+        return false;
+
+      return match(Compare->getOperand(0), m_And(m_Value(A), m_Value(B)));
+    } else if (Compare->isSigned()) {
+      // Generalize the isZero() check to the equivalent form
+      // (A l<< K) s>=/s< 0. In the middle-end, (A & (signbit l>> K)) ==/!= 0
+      // can be folded into (A l<< K) s>=/s< 0.
+      if (Compare->getPredicate() !=
+          (JoinedByAnd ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGT))
+        return false;
+
+      // TODO support vector splats
+      ConstantInt *C = dyn_cast<ConstantInt>(Compare->getOperand(1));
+      if (!C || (JoinedByAnd ? !C->isZero() : !C->isMinusOne()))
+        return false;
+
+      return match(Compare->getOperand(0),
+                   m_OneUse(m_Shl(m_Value(A), m_Value(B))));
     }
+
+    return false;
+  };
+
+  Value *A, *B, *C, *D;
+  if (!isCompareToZeroOrEquivalentForm(LHS, JoinedByAnd, Builder, A, B) ||
+      !isCompareToZeroOrEquivalentForm(RHS, JoinedByAnd, Builder, C, D))
+    return nullptr;
+
+  if (A == D || B == D)
+    std::swap(C, D);
+  if (B == C)
+    std::swap(A, B);
+
+  if (A == C) {
+    if ((LHS->isEquality() && !isKnownToBeAPowerOfTwo(B, false, 0, &CxtI)) ||
+        (RHS->isEquality() && !isKnownToBeAPowerOfTwo(D, false, 0, &CxtI)))
+      return nullptr;
+
+    auto getSelfOrCreateSignbitShift = [](ICmpInst *Compare,
+                                          InstCombiner::BuilderTy &Builder,
+                                          Value *B) -> Value * {
+      // Partially decompose (A l<< K) s>=/s< 0 into (A & (signbit l>> K))
+      // ==/!= 0 by creating the signbit shift.
+      if (Compare->isSigned())
+        return Builder.CreateLShr(
+            ConstantInt::get(
+                B->getType(),
+                APInt::getSignMask(B->getType()->getScalarSizeInBits())),
+            B);
+      else
+        return B;
+    };
+
+    Value *Mask =
+        Builder.CreateOr(getSelfOrCreateSignbitShift(LHS, Builder, B),
+                         getSelfOrCreateSignbitShift(RHS, Builder, D));
+    Value *Masked = Builder.CreateAnd(A, Mask);
+    auto NewPred = JoinedByAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
+
+    return Builder.CreateICmp(NewPred, Masked, Mask);
   }
 
   return nullptr;
Index: llvm/test/Transforms/InstCombine/onehot_merge.ll
===================================================================
--- llvm/test/Transforms/InstCombine/onehot_merge.ll
+++ llvm/test/Transforms/InstCombine/onehot_merge.ll
@@ -153,12 +153,11 @@
 define i1 @foo1_and_signbit_lshr_without_shifting_signbit(i32 %k, i32 %c1, i32 %c2) {
 ; CHECK-LABEL: @foo1_and_signbit_lshr_without_shifting_signbit(
 ; CHECK-NEXT:    [[T0:%.*]] = shl i32 1, [[C1:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = and i32 [[T0]], [[K:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = icmp eq i32 [[T1]], 0
-; CHECK-NEXT:    [[T3:%.*]] = shl i32 [[K]], [[C2:%.*]]
-; CHECK-NEXT:    [[T4:%.*]] = icmp sgt i32 [[T3]], -1
-; CHECK-NEXT:    [[OR:%.*]] = or i1 [[T2]], [[T4]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 -2147483648, [[C2:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = or i32 [[T0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], [[K:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne i32 [[TMP3]], [[TMP2]]
+; CHECK-NEXT:    ret i1 [[TMP4]]
 ;
   %t0 = shl i32 1, %c1
  %t1 = and i32 %t0, %k
@@ -172,12 +171,11 @@
 define i1 @foo1_or_signbit_lshr_without_shifting_signbit(i32 %k, i32 %c1, i32 %c2) {
 ; CHECK-LABEL: @foo1_or_signbit_lshr_without_shifting_signbit(
 ; CHECK-NEXT:    [[T0:%.*]] = shl i32 1, [[C1:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = and i32 [[T0]], [[K:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = icmp ne i32 [[T1]], 0
-; CHECK-NEXT:    [[T3:%.*]] = shl i32 [[K]], [[C2:%.*]]
-; CHECK-NEXT:    [[T4:%.*]] = icmp slt i32 [[T3]], 0
-; CHECK-NEXT:    [[OR:%.*]] = and i1 [[T2]], [[T4]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 -2147483648, [[C2:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = or i32 [[T0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], [[K:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i32 [[TMP3]], [[TMP2]]
+; CHECK-NEXT:    ret i1 [[TMP4]]
 ;
  %t0 = shl i32 1, %c1
  %t1 = and i32 %t0, %k
@@ -245,7 +243,7 @@
   ret i1 %or
 }
 
-; Should not fold
+; Expect to fold
 define i1 @foo1_and_extra_use_and(i32 %k, i32 %c1, i32 %c2, i32* %p) {
 ; CHECK-LABEL: @foo1_and_extra_use_and(
 ; CHECK-NEXT:    [[T0:%.*]] = shl i32 1, [[C1:%.*]]
@@ -276,10 +274,10 @@
 ; CHECK-NEXT:    [[T2:%.*]] = and i32 [[T0]], [[K:%.*]]
 ; CHECK-NEXT:    [[T3:%.*]] = icmp eq i32 [[T2]], 0
 ; CHECK-NEXT:    store i1 [[T3]], i1* [[P:%.*]], align 1
-; CHECK-NEXT:    [[TMP1:%.*]] = or i32 [[T0]], [[T1]]
-; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], [[K]]
-; CHECK-NEXT:    [[TMP3:%.*]] = icmp ne i32 [[TMP2]], [[TMP1]]
-; CHECK-NEXT:    ret i1 [[TMP3]]
+; CHECK-NEXT:    [[T4:%.*]] = and i32 [[T1]], [[K]]
+; CHECK-NEXT:    [[T5:%.*]] = icmp eq i32 [[T4]], 0
+; CHECK-NEXT:    [[OR:%.*]] = or i1 [[T3]], [[T5]]
+; CHECK-NEXT:    ret i1 [[OR]]
 ;
  %t0 = shl i32 1, %c1
  %t1 = shl i32 1, %c2
@@ -314,7 +312,7 @@
   ret i1 %or
 }
 
-; Should not fold
+; Expect to fold
 define i1 @foo1_and_extra_use_and2(i32 %k, i32 %c1, i32 %c2, i32* %p) {
 ; CHECK-LABEL: @foo1_and_extra_use_and2(
 ; CHECK-NEXT:    [[T0:%.*]] = shl i32 1, [[C1:%.*]]
@@ -342,13 +340,13 @@
 ; CHECK-LABEL: @foo1_and_extra_use_cmp2(
 ; CHECK-NEXT:    [[T0:%.*]] = shl i32 1, [[C1:%.*]]
 ; CHECK-NEXT:    [[T1:%.*]] = shl i32 1, [[C2:%.*]]
-; CHECK-NEXT:    [[T4:%.*]] = and i32 [[T1]], [[K:%.*]]
+; CHECK-NEXT:    [[T2:%.*]] = and i32 [[T0]], [[K:%.*]]
+; CHECK-NEXT:    [[T3:%.*]] = icmp eq i32 [[T2]], 0
+; CHECK-NEXT:    [[T4:%.*]] = and i32 [[T1]], [[K]]
 ; CHECK-NEXT:    [[T5:%.*]] = icmp eq i32 [[T4]], 0
 ; CHECK-NEXT:    store i1 [[T5]], i1* [[P:%.*]], align 1
-; CHECK-NEXT:    [[TMP1:%.*]] = or i32 [[T0]], [[T1]]
-; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], [[K]]
-; CHECK-NEXT:    [[TMP3:%.*]] = icmp ne i32 [[TMP2]], [[TMP1]]
-; CHECK-NEXT:    ret i1 [[TMP3]]
+; CHECK-NEXT:    [[OR:%.*]] = or i1 [[T3]], [[T5]]
+; CHECK-NEXT:    ret i1 [[OR]]
 ;
  %t0 = shl i32 1, %c1
  %t1 = shl i32 1, %c2
@@ -367,12 +365,11 @@
 ; CHECK-LABEL: @foo1_and_signbit_lshr_without_shifting_signbit_extra_use_shl1(
 ; CHECK-NEXT:    [[T0:%.*]] = shl i32 1, [[C1:%.*]]
 ; CHECK-NEXT:    store i32 [[T0]], i32* [[P:%.*]], align 4
-; CHECK-NEXT:    [[T1:%.*]] = and i32 [[T0]], [[K:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = icmp eq i32 [[T1]], 0
-; CHECK-NEXT:    [[T3:%.*]] = shl i32 [[K]], [[C2:%.*]]
-; CHECK-NEXT:    [[T4:%.*]] = icmp sgt i32 [[T3]], -1
-; CHECK-NEXT:    [[OR:%.*]] = or i1 [[T2]], [[T4]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 -2147483648, [[C2:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = or i32 [[T0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], [[K:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne i32 [[TMP3]], [[TMP2]]
+; CHECK-NEXT:    ret i1 [[TMP4]]
 ;
  %t0 = shl i32 1, %c1
  store i32 %t0, i32* %p  ; extra use of shl
@@ -384,17 +381,17 @@
   ret i1 %or
 }
 
-; Not fold
+; Expect to fold
 define i1 @foo1_and_signbit_lshr_without_shifting_signbit_extra_use_and(i32 %k, i32 %c1, i32 %c2, i32* %p) {
 ; CHECK-LABEL: @foo1_and_signbit_lshr_without_shifting_signbit_extra_use_and(
 ; CHECK-NEXT:    [[T0:%.*]] = shl i32 1, [[C1:%.*]]
 ; CHECK-NEXT:    [[T1:%.*]] = and i32 [[T0]], [[K:%.*]]
 ; CHECK-NEXT:    store i32 [[T1]], i32* [[P:%.*]], align 4
-; CHECK-NEXT:    [[T2:%.*]] = icmp eq i32 [[T1]], 0
-; CHECK-NEXT:    [[T3:%.*]] = shl i32 [[K]], [[C2:%.*]]
-; CHECK-NEXT:    [[T4:%.*]] = icmp sgt i32 [[T3]], -1
-; CHECK-NEXT:    [[OR:%.*]] = or i1 [[T2]], [[T4]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 -2147483648, [[C2:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = or i32 [[T0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], [[K]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne i32 [[TMP3]], [[TMP2]]
+; CHECK-NEXT:    ret i1 [[TMP4]]
 ;
  %t0 = shl i32 1, %c1
  %t1 = and i32 %t0, %k
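For reference, a minimal before/after sketch of the generalized fold, taken from the @foo1_and_signbit_lshr_without_shifting_signbit test above; the value names in the "after" snippet are illustrative only, not names produced by the pass:

  before:
    %t1 = and i32 %t0, %k             ; %t0 = shl i32 1, %c1, a power-of-two mask
    %t2 = icmp eq i32 %t1, 0
    %t3 = shl i32 %k, %c2
    %t4 = icmp sgt i32 %t3, -1        ; i.e. (%k l<< %c2) s>= 0
    %or = or i1 %t2, %t4

  after:
    %sign = lshr i32 -2147483648, %c2 ; signbit l>> K, from getSelfOrCreateSignbitShift
    %mask = or i32 %t0, %sign
    %and  = and i32 %mask, %k
    %res  = icmp ne i32 %and, %mask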