diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -4079,6 +4079,28 @@
                           Constant::getNullValue(X->getType()), I.getName());
 }
 
+/// Try to determine the sign of \p Op at \p CxtI. Returns true if \p Op is
+/// known negative, false if it is known non-negative, and std::nullopt if
+/// unknown. The known bits of \p Op are always returned through \p Known.
+static std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI,
+                                        const DataLayout &DL,
+                                        AssumptionCache *AC, DominatorTree *DT,
+                                        KnownBits &Known) {
+  Known = computeKnownBits(Op, DL, 0, AC, CxtI, DT);
+  if (Known.isNonNegative())
+    return false;
+  if (Known.isNegative())
+    return true;
+
+  // With nsw, X - Y is negative exactly when X s< Y.
+  Value *X, *Y;
+  if (match(Op, m_NSWSub(m_Value(X), m_Value(Y))))
+    return isImpliedByDomCondition(ICmpInst::ICMP_SLT, X, Y, CxtI, DL);
+
+  return isImpliedByDomCondition(
+      ICmpInst::ICMP_SLT, Op, Constant::getNullValue(Op->getType()), CxtI, DL);
+}
+
 /// Try to fold icmp (binop), X or icmp X, (binop).
 /// TODO: A large part of this logic is duplicated in InstSimplify's
 /// simplifyICmpWithBinOp(). We should be able to share that and avoid the code
@@ -4375,6 +4397,112 @@
                               ConstantExpr::getNeg(RHSC));
   }
 
+  {
+    bool Op0IsXor = false, Op1IsXor = false;
+    if (match(Op0, m_c_Xor(m_Specific(Op1), m_Value(A))))
+      Op0IsXor = true;
+    else if (match(Op1, m_c_Xor(m_Specific(Op0), m_Value(A))))
+      Op1IsXor = true;
+    if (Op0IsXor || Op1IsXor) {
+
+      // icmp (X ^ Y_NonZero) u>= X --> icmp (X ^ Y_NonZero) u> X
+      // icmp (X ^ Y_NonZero) u<= X --> icmp (X ^ Y_NonZero) u< X
+      // icmp (X ^ Y_NonZero) s>= X --> icmp (X ^ Y_NonZero) s> X
+      // icmp (X ^ Y_NonZero) s<= X --> icmp (X ^ Y_NonZero) s< X
+      CmpInst::Predicate PredOut = Pred;
+      switch (Pred) {
+      case ICmpInst::ICMP_ULE:
+        PredOut = ICmpInst::ICMP_ULT;
+        break;
+      case ICmpInst::ICMP_UGE:
+        PredOut = ICmpInst::ICMP_UGT;
+        break;
+      case ICmpInst::ICMP_SLE:
+        PredOut = ICmpInst::ICMP_SLT;
+        break;
+      case ICmpInst::ICMP_SGE:
+        PredOut = ICmpInst::ICMP_SGT;
+        break;
+      default:
+        break;
+      }
+      if (PredOut != Pred &&
+          isKnownNonZero(A, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
+        return new ICmpInst(PredOut, Op0, Op1);
+
+      // icmp (X ^ Y) u> X --> X & MSB(Y) == 0
+      // icmp (X ^ Y) u< X --> X & MSB(Y) != 0
+      // icmp (X ^ Pos_Y) s> X --> X & MSB(Pos_Y) == 0
+      // icmp (X ^ Pos_Y) s< X --> X & MSB(Pos_Y) != 0
+      switch (Pred) {
+      case ICmpInst::ICMP_UGT:
+      case ICmpInst::ICMP_SGT:
+        PredOut = Op0IsXor ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
+        break;
+      case ICmpInst::ICMP_ULT:
+      case ICmpInst::ICMP_SLT:
+        PredOut = Op0IsXor ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ;
+        break;
+      default:
+        break;
+      }
+
+      if (PredOut != Pred) {
+        Value *Xor = Op0IsXor ? Op0 : Op1;
+        Value *X = Op0IsXor ? Op1 : Op0;
+        bool Usable = true;
+        std::optional<KnownBits> KnownOpt;
+        // For signed predicates the fold also depends on the sign of A.
+        if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT) {
+          KnownBits Known;
+          auto KnownSign = getKnownSign(A, &I, DL, &AC, &DT, Known);
+          KnownOpt = Known;
+          Usable = false;
+          if (KnownSign != std::nullopt) {
+            // icmp (X ^ Neg_Y) s> X --> X s< 0
+            // icmp (X ^ Neg_Y) s< X --> X s>= 0
+            if (*KnownSign /* true means A is negative. */) {
+              // A has its sign bit set, so X ^ A always has the opposite
+              // sign of X and the compare reduces to a sign test on X.
+              if (Pred == ICmpInst::ICMP_SGT)
+                PredOut = Op0IsXor ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE;
+              if (Pred == ICmpInst::ICMP_SLT)
+                PredOut = Op0IsXor ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_SLT;
+              return new ICmpInst(PredOut, X,
+                                  Constant::getNullValue(Op0->getType()));
+            }
+            Usable = !*KnownSign;
+          }
+        }
+
+        if (Usable && Xor->hasOneUse()) {
+          // A known non-zero power of 2 is its own MSB. Zero must be excluded:
+          // X ^ 0 == X makes the strict compare false, yet (X & 0) == 0.
+          if (isKnownToBeAPowerOfTwo(A, /*OrZero*/ false, 0, &I))
+            return new ICmpInst(PredOut, Builder.CreateAnd(X, A),
+                                Constant::getNullValue(Op0->getType()));
+
+          // Otherwise locate MSB(A) from known bits, if its position is fixed.
+          if (KnownOpt == std::nullopt)
+            KnownOpt = computeKnownBits(A, 0, &I);
+          unsigned MSBPos = KnownOpt->countMaxLeadingZeros();
+          unsigned BitWidth = KnownOpt->getBitWidth();
+          if (MSBPos == 0 || (MSBPos != BitWidth &&
+                              MSBPos == KnownOpt->countMinLeadingZeros())) {
+            Type *Ty = Op0->getType();
+            return new ICmpInst(
+                PredOut,
+                Builder.CreateAnd(
+                    X,
+                    ConstantInt::get(Ty, APInt::getOneBitSet(
+                                             BitWidth, BitWidth - MSBPos - 1))),
+                Constant::getNullValue(Ty));
+          }
+        }
+      }
+    }
+  }
+
   {
     // Try to remove shared multiplier from comparison:
     // X * Z u{lt/le/gt/ge}/eq/ne Y * Z
diff --git a/llvm/test/Transforms/InstCombine/icmp-of-xor-x.ll b/llvm/test/Transforms/InstCombine/icmp-of-xor-x.ll
--- a/llvm/test/Transforms/InstCombine/icmp-of-xor-x.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-of-xor-x.ll
@@ -9,7 +9,7 @@
 ; CHECK-NEXT:    [[YNZ:%.*]] = icmp ne i8 [[Y:%.*]], 0
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[YNZ]])
 ; CHECK-NEXT:    [[XOR:%.*]] = xor i8 [[X:%.*]], [[Y]]
-; CHECK-NEXT:    [[R:%.*]] = icmp uge i8 [[XOR]], [[X]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ugt i8 [[XOR]], [[X]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %ynz = icmp ne i8 %y, 0
@@ -34,7 +34,7 @@
 ; CHECK-LABEL: @xor_ule_2(
 ; CHECK-NEXT:    [[Y:%.*]] = or <2 x i8> [[YY:%.*]],
 ; CHECK-NEXT:    [[XOR:%.*]] = xor <2 x i8> [[Y]], [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = icmp ule <2 x i8> [[XOR]], [[X]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ult <2 x i8> [[XOR]], [[X]]
 ; CHECK-NEXT:    ret <2 x i1> [[R]]
 ;
   %y = or <2 x i8> %yy,
@@ -49,7 +49,7 @@
 ; CHECK-NEXT:    [[YNZ:%.*]] = icmp ne i8 [[Y:%.*]], 0
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[YNZ]])
 ; CHECK-NEXT:    [[XOR:%.*]] = xor i8 [[X]], [[Y]]
-; CHECK-NEXT:    [[R:%.*]] = icmp sle i8 [[X]], [[XOR]]
+; CHECK-NEXT:    [[R:%.*]] = icmp slt i8 [[X]], [[XOR]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %x = add i8 %xx, %z
@@ -62,9 +62,7 @@
 
 define i1 @xor_sge(i8 %x, i8 %yy) {
 ; CHECK-LABEL: @xor_sge(
-; CHECK-NEXT:    [[Y:%.*]] = or i8 [[YY:%.*]], -128
-; CHECK-NEXT:    [[XOR:%.*]] = xor i8 [[Y]], [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = icmp sge i8 [[XOR]], [[X]]
+; CHECK-NEXT:    [[R:%.*]] = icmp slt i8 [[X:%.*]], 0
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %y = or i8 %yy, 128
@@ -76,10 +74,8 @@
 define i1 @xor_ugt_2(i8 %xx, i8 %y, i8 %z) {
 ; CHECK-LABEL: @xor_ugt_2(
 ; CHECK-NEXT:    [[X:%.*]] = add i8 [[XX:%.*]], [[Z:%.*]]
-; CHECK-NEXT:    [[YZ:%.*]] = and i8 [[Y:%.*]], 63
-; CHECK-NEXT:    [[Y1:%.*]] = or i8 [[YZ]], 64
-; CHECK-NEXT:    [[XOR:%.*]] = xor i8 [[X]], [[Y1]]
-; CHECK-NEXT:    [[R:%.*]] = icmp ugt i8 [[X]], [[XOR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X]], 64
+; CHECK-NEXT:    [[R:%.*]] = icmp ne i8 [[TMP1]], 0
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %x = add i8 %xx, %z
@@ -92,8 +88,8 @@
 
 define i1 @xor_ult(i8 %x) {
 ; CHECK-LABEL: @xor_ult(
-; CHECK-NEXT:    [[XOR:%.*]] = xor i8 [[X:%.*]], 123
-; CHECK-NEXT:    [[R:%.*]] = icmp ult i8 [[XOR]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X:%.*]], 64
+; CHECK-NEXT:    [[R:%.*]] = icmp ne i8 [[TMP1]], 0
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %xor = xor i8 %x, 123
@@ -103,10 +99,8 @@
 
 define <2 x i1> @xor_sgt(<2 x i8> %x, <2 x i8> %y) {
 ; CHECK-LABEL: @xor_sgt(
-; CHECK-NEXT:    [[YZ:%.*]] = and <2 x i8> [[Y:%.*]],
-; CHECK-NEXT:    [[Y1:%.*]] = or <2 x i8> [[YZ]],
-; CHECK-NEXT:    [[XOR:%.*]] = xor <2 x i8> [[Y1]], [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = icmp sgt <2 x i8> [[XOR]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i8> [[X:%.*]],
+; CHECK-NEXT:    [[R:%.*]] = icmp eq <2 x i8> [[TMP1]], zeroinitializer
 ; CHECK-NEXT:    ret <2 x i1> [[R]]
 ;
   %yz = and <2 x i8> %y,
@@ -133,8 +127,8 @@
 
 define i1 @xor_slt_2(i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @xor_slt_2(
-; CHECK-NEXT:    [[XOR:%.*]] = xor i8 [[X:%.*]], 88
-; CHECK-NEXT:    [[R:%.*]] = icmp sgt i8 [[XOR]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X:%.*]], 64
+; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[TMP1]], 0
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %xor = xor i8 %x, 88
@@ -145,9 +139,7 @@
 define <2 x i1> @xor_sgt_intmin_2(<2 x i8> %xx, <2 x i8> %yy, <2 x i8> %z) {
 ; CHECK-LABEL: @xor_sgt_intmin_2(
 ; CHECK-NEXT:    [[X:%.*]] = add <2 x i8> [[XX:%.*]], [[Z:%.*]]
-; CHECK-NEXT:    [[Y:%.*]] = or <2 x i8> [[YY:%.*]],
-; CHECK-NEXT:    [[XOR:%.*]] = xor <2 x i8> [[X]], [[Y]]
-; CHECK-NEXT:    [[R:%.*]] = icmp sgt <2 x i8> [[X]], [[XOR]]
+; CHECK-NEXT:    [[R:%.*]] = icmp sgt <2 x i8> [[X]],
 ; CHECK-NEXT:    ret <2 x i1> [[R]]
 ;
   %x = add <2 x i8> %xx, %z
@@ -165,8 +157,7 @@
 ; CHECK-NEXT:    [[COMMON_RET_OP:%.*]] = phi i1 [ [[R:%.*]], [[NEG]] ], [ false, [[POS]] ]
 ; CHECK-NEXT:    ret i1 [[COMMON_RET_OP]]
 ; CHECK:       neg:
-; CHECK-NEXT:    [[XOR:%.*]] = xor i8 [[C]], [[X:%.*]]
-; CHECK-NEXT:    [[R]] = icmp slt i8 [[XOR]], [[X]]
+; CHECK-NEXT:    [[R]] = icmp sgt i8 [[X:%.*]], -1
 ; CHECK-NEXT:    br label [[COMMON_RET:%.*]]
 ; CHECK:       pos:
 ; CHECK-NEXT:    tail call void @barrier()