diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1064,7 +1064,7 @@
     V = Builder.CreateLShr(V, P.StartBit);
   Type *TruncTy = V->getType()->getWithNewBitWidth(P.NumBits);
   if (TruncTy != V->getType())
-    V = Builder.CreateTrunc(V, TruncTy);
+    V = Builder.CreateZExtOrTrunc(V, TruncTy);
   return V;
 }
 
@@ -1077,13 +1077,57 @@
     return nullptr;
 
   CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
-  if (Cmp0->getPredicate() != Pred || Cmp1->getPredicate() != Pred)
+  auto MatchPred = [&](ICmpInst *Cmp) -> std::pair<bool, const APInt *> {
+    if (Pred == Cmp->getPredicate())
+      return {true, nullptr};
+
+    const APInt *C;
+    // (icmp eq (lshr x, C), (lshr y, C)) gets optimized to:
+    // (icmp ult (xor x, y), 1 << C) so also look for that.
+    if (Pred == CmpInst::ICMP_EQ && Cmp->getPredicate() == CmpInst::ICMP_ULT)
+      return {match(Cmp->getOperand(1), m_APInt(C)) && C->isPowerOf2() &&
+                  match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value())),
+              C};
+
+    // (icmp ne (lshr x, C), (lshr y, C)) gets optimized to:
+    // (icmp uge (xor x, y), (1 << C) - 1) so also look for that.
+    if (Pred == CmpInst::ICMP_NE && Cmp->getPredicate() == CmpInst::ICMP_UGT)
+      return {match(Cmp->getOperand(1), m_APInt(C)) && C->isMask() &&
+                  match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value())),
+              C};
+
+    return {false, nullptr};
+  };
+
+  auto GetMatchPart = [&](std::pair<bool, const APInt *> MatchResult,
+                          ICmpInst *Cmp,
+                          unsigned OpNo) -> std::optional<IntPart> {
+    // Normal IntPart
+    if (MatchResult.second == nullptr)
+      return matchIntPart(Cmp->getOperand(OpNo));
+
+    // We have one of the ult/ugt patterns.
+    unsigned From;
+    const APInt *C = MatchResult.second;
+    if (Pred == CmpInst::ICMP_NE)
+      From = C->popcount();
+    else
+      From = (*C - 1).popcount();
+    Instruction *I = cast<Instruction>(Cmp->getOperand(0));
+    return {{I->getOperand(OpNo), From,
+             Cmp->getOperand(0)->getType()->getScalarSizeInBits()}};
+  };
+
+  auto Cmp0Match = MatchPred(Cmp0);
+  auto Cmp1Match = MatchPred(Cmp1);
+  if (!Cmp0Match.first || !Cmp1Match.first)
     return nullptr;
 
-  std::optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0));
-  std::optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1));
-  std::optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0));
-  std::optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1));
+  std::optional<IntPart> L0 = GetMatchPart(Cmp0Match, Cmp0, 0);
+  std::optional<IntPart> R0 = GetMatchPart(Cmp0Match, Cmp0, 1);
+  std::optional<IntPart> L1 = GetMatchPart(Cmp1Match, Cmp1, 0);
+  std::optional<IntPart> R1 = GetMatchPart(Cmp1Match, Cmp1, 1);
+
   if (!L0 || !R0 || !L1 || !R1)
     return nullptr;
diff --git a/llvm/test/Transforms/InstCombine/eq-of-parts.ll b/llvm/test/Transforms/InstCombine/eq-of-parts.ll
--- a/llvm/test/Transforms/InstCombine/eq-of-parts.ll
+++ b/llvm/test/Transforms/InstCombine/eq-of-parts.ll
@@ -1336,12 +1336,7 @@
 
 define i1 @eq_optimized_highbits_cmp(i32 %x, i32 %y) {
 ; CHECK-LABEL: @eq_optimized_highbits_cmp(
-; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[CMP_HI:%.*]] = icmp ult i32 [[XOR]], 33554432
-; CHECK-NEXT:    [[TX:%.*]] = trunc i32 [[X]] to i25
-; CHECK-NEXT:    [[TY:%.*]] = trunc i32 [[Y]] to i25
-; CHECK-NEXT:    [[CMP_LO:%.*]] = icmp eq i25 [[TX]], [[TY]]
-; CHECK-NEXT:    [[R:%.*]] = and i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT:    [[R:%.*]] = icmp eq i32 [[Y:%.*]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %xor = xor i32 %y, %x
@@ -1393,12 +1388,7 @@
 
 define i1 @ne_optimized_highbits_cmp(i32 %x, i32 %y) {
 ; CHECK-LABEL: @ne_optimized_highbits_cmp(
-; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[CMP_HI:%.*]] = icmp ugt i32 [[XOR]], 16777215
-; CHECK-NEXT:    [[TX:%.*]] = trunc i32 [[X]] to i24
-; CHECK-NEXT:    [[TY:%.*]] = trunc i32 [[Y]] to i24
-; CHECK-NEXT:    [[CMP_LO:%.*]] = icmp ne i24 [[TX]], [[TY]]
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ne i32 [[Y:%.*]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %xor = xor i32 %y, %x