diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -4085,6 +4085,30 @@
                         Constant::getNullValue(X->getType()), I.getName());
 }
 
+static Instruction *foldICmpXorXX(ICmpInst &I, const SimplifyQuery &Q,
+                                  InstCombinerImpl &IC) {
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A;
+  // Normalize xor operand as operand 0.
+  CmpInst::Predicate Pred = I.getPredicate();
+  if (match(Op1, m_c_Xor(m_Specific(Op0), m_Value()))) {
+    std::swap(Op0, Op1);
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+  }
+  if (!match(Op0, m_c_Xor(m_Specific(Op1), m_Value(A))))
+    return nullptr;
+
+  // icmp (X ^ Y_NonZero) u>= X --> icmp (X ^ Y_NonZero) u> X
+  // icmp (X ^ Y_NonZero) u<= X --> icmp (X ^ Y_NonZero) u< X
+  // icmp (X ^ Y_NonZero) s>= X --> icmp (X ^ Y_NonZero) s> X
+  // icmp (X ^ Y_NonZero) s<= X --> icmp (X ^ Y_NonZero) s< X
+  CmpInst::Predicate PredOut = CmpInst::getStrictPredicate(Pred);
+  if (PredOut != Pred &&
+      isKnownNonZero(A, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
+    return new ICmpInst(PredOut, Op0, Op1);
+
+  return nullptr;
+}
+
 /// Try to fold icmp (binop), X or icmp X, (binop).
 /// TODO: A large part of this logic is duplicated in InstSimplify's
 /// simplifyICmpWithBinOp(). We should be able to share that and avoid the code
@@ -4381,6 +4405,9 @@
                           ConstantExpr::getNeg(RHSC));
   }
 
+  if (Instruction *R = foldICmpXorXX(I, Q, *this))
+    return R;
+
   {
     // Try to remove shared multiplier from comparison:
     // X * Z u{lt/le/gt/ge}/eq/ne Y * Z
diff --git a/llvm/test/Transforms/InstCombine/icmp-of-xor-x.ll b/llvm/test/Transforms/InstCombine/icmp-of-xor-x.ll
--- a/llvm/test/Transforms/InstCombine/icmp-of-xor-x.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-of-xor-x.ll
@@ -9,7 +9,7 @@
 ; CHECK-NEXT:    [[YNZ:%.*]] = icmp ne i8 [[Y:%.*]], 0
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[YNZ]])
 ; CHECK-NEXT:    [[XOR:%.*]] = xor i8 [[X:%.*]], [[Y]]
-; CHECK-NEXT:    [[R:%.*]] = icmp uge i8 [[XOR]], [[X]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ugt i8 [[XOR]], [[X]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %ynz = icmp ne i8 %y, 0
@@ -34,7 +34,7 @@
 ; CHECK-LABEL: @xor_ule_2(
 ; CHECK-NEXT:    [[Y:%.*]] = or <2 x i8> [[YY:%.*]],
 ; CHECK-NEXT:    [[XOR:%.*]] = xor <2 x i8> [[Y]], [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = icmp ule <2 x i8> [[XOR]], [[X]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ult <2 x i8> [[XOR]], [[X]]
 ; CHECK-NEXT:    ret <2 x i1> [[R]]
 ;
   %y = or <2 x i8> %yy,
@@ -49,7 +49,7 @@
 ; CHECK-NEXT:    [[YNZ:%.*]] = icmp ne i8 [[Y:%.*]], 0
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[YNZ]])
 ; CHECK-NEXT:    [[XOR:%.*]] = xor i8 [[X]], [[Y]]
-; CHECK-NEXT:    [[R:%.*]] = icmp sle i8 [[X]], [[XOR]]
+; CHECK-NEXT:    [[R:%.*]] = icmp sgt i8 [[XOR]], [[X]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %x = add i8 %xx, %z
@@ -64,7 +64,7 @@
 ; CHECK-LABEL: @xor_sge(
 ; CHECK-NEXT:    [[Y:%.*]] = or i8 [[YY:%.*]], -128
 ; CHECK-NEXT:    [[XOR:%.*]] = xor i8 [[Y]], [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = icmp sge i8 [[XOR]], [[X]]
+; CHECK-NEXT:    [[R:%.*]] = icmp sgt i8 [[XOR]], [[X]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %y = or i8 %yy, 128
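
The soundness premise the patch relies on is that X ^ A != X whenever A != 0, so once isKnownNonZero proves A, the non-strict predicates against X collapse to their strict forms. For the i8 case this can be cross-checked exhaustively; the standalone sketch below (not part of the patch, names chosen for illustration) brute-forces all (x, y) pairs with y != 0 and verifies the four predicate rewrites foldICmpXorXX performs.

// Exhaustive i8 sanity check of the fold: for y != 0, x ^ y != x, hence
// uge/ule/sge/sle against x are equivalent to ugt/ult/sgt/slt.
#include <cstdint>
#include <cstdio>

int main() {
  for (int x = 0; x < 256; ++x) {
    for (int y = 1; y < 256; ++y) { // y == 0 excluded: the fold requires a non-zero xor operand.
      uint8_t ux = (uint8_t)x, uxor = (uint8_t)(x ^ y);
      int8_t sx = (int8_t)ux, sxor = (int8_t)uxor;
      // Premise: xor with a non-zero value never leaves x unchanged.
      if (uxor == ux)
        return printf("premise fails: x=%d y=%d\n", x, y), 1;
      // The four predicate relaxations the patch performs.
      if ((uxor >= ux) != (uxor > ux) || (uxor <= ux) != (uxor < ux) ||
          (sxor >= sx) != (sxor > sx) || (sxor <= sx) != (sxor < sx))
        return printf("fold fails: x=%d y=%d\n", x, y), 1;
    }
  }
  puts("all 256 * 255 (x, y) pairs hold");
  return 0;
}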