diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -234,11 +234,12 @@ /// LHS and RHS are the left hand side and the right hand side ICmps and PredL /// and PredR are their predicates, respectively. static std::optional> getMaskedTypeForICmpPair( - Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, ICmpInst *LHS, - ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) { + Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, + InstCombinerImpl::ICmpComponents LHS, InstCombinerImpl::ICmpComponents RHS, + ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) { // Don't allow pointers. Splat vectors are fine. - if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() || - !RHS->getOperand(0)->getType()->isIntOrIntVectorTy()) + if (!LHS.getOperand(0)->getType()->isIntOrIntVectorTy() || + !RHS.getOperand(0)->getType()->isIntOrIntVectorTy()) return std::nullopt; // Here comes the tricky part: @@ -247,8 +248,8 @@ // Now we must find those components L** and R**, that are equal, so // that we can extract the parameters A, B, C, D, and E for the canonical // above. - Value *L1 = LHS->getOperand(0); - Value *L2 = LHS->getOperand(1); + Value *L1 = LHS.getOperand(0); + Value *L2 = LHS.getOperand(1); Value *L11, *L12, *L21, *L22; // Check whether the icmp can be decomposed into a bit test. if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) { @@ -272,8 +273,8 @@ if (!ICmpInst::isEquality(PredL)) return std::nullopt; - Value *R1 = RHS->getOperand(0); - Value *R2 = RHS->getOperand(1); + Value *R1 = RHS.getOperand(0); + Value *R2 = RHS.getOperand(1); Value *R11, *R12; bool Ok = false; if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) { @@ -364,8 +365,9 @@ /// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8). 
/// Also used for logical and/or, must be poison safe. static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( - ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C, - Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, + InstCombinerImpl::ICmpComponents LHS, InstCombinerImpl::ICmpComponents RHS, + bool IsAnd, Value *A, Value *B, Value *C, Value *D, Value *E, + ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, InstCombiner::BuilderTy &Builder) { // We are given the canonical form: // (icmp ne (A & B), 0) & (icmp eq (A & D), E). @@ -456,7 +458,7 @@ // (icmp ne (A & 15), 0) & (icmp eq (A & 3), 0) -> no folding. if (ECst.isZero()) { if (IsSubSetOrEqual(BCst, DCst)) - return ConstantInt::get(LHS->getType(), !IsAnd); + return ConstantInt::get(LHS.getType(), !IsAnd); return nullptr; } @@ -467,18 +469,18 @@ // (icmp ne (A & 255), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8). // (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8). if (IsSuperSetOrEqual(BCst, DCst)) - return RHS; + return RHS.getOrCreateValue(Builder); // Otherwise, B is a subset of D. If B and E have a common bit set, // ie. (B & E) != 0, then LHS is subsumed by RHS. For example. // (icmp ne (A & 12), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8). assert(IsSubSetOrEqual(BCst, DCst) && "Precondition due to above code"); if ((*BCst & ECst) != 0) - return RHS; + return RHS.getOrCreateValue(Builder); // Otherwise, LHS and RHS contradict and the whole expression becomes false // (or true if negated.) For example, // (icmp ne (A & 7), 0) & (icmp eq (A & 15), 8) -> false. // (icmp ne (A & 6), 0) & (icmp eq (A & 15), 8) -> false. - return ConstantInt::get(LHS->getType(), !IsAnd); + return ConstantInt::get(LHS.getType(), !IsAnd); } /// Try to fold (icmp(A & B) ==/!= 0) &/| (icmp(A & D) ==/!= E) into a single @@ -486,9 +488,10 @@ /// aren't of the common mask pattern type. /// Also used for logical and/or, must be poison safe. 
static Value *foldLogOpOfMaskedICmpsAsymmetric( - ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C, - Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, - unsigned LHSMask, unsigned RHSMask, InstCombiner::BuilderTy &Builder) { + InstCombinerImpl::ICmpComponents LHS, InstCombinerImpl::ICmpComponents RHS, + bool IsAnd, Value *A, Value *B, Value *C, Value *D, Value *E, + ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, unsigned LHSMask, + unsigned RHSMask, InstCombiner::BuilderTy &Builder) { assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) && "Expected equality predicates for masked type of icmps."); // Handle Mask_NotAllZeros-BMask_Mixed cases. @@ -518,11 +521,12 @@ /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) /// into a single (icmp(A & X) ==/!= Y). -static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, - bool IsLogical, +static Value *foldLogOpOfMaskedICmps(InstCombinerImpl::ICmpComponents LHS, + InstCombinerImpl::ICmpComponents RHS, + bool IsAnd, bool IsLogical, InstCombiner::BuilderTy &Builder) { Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; - ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); + ICmpInst::Predicate PredL = LHS.getPredicate(), PredR = RHS.getPredicate(); std::optional> MaskPair = getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR); if (!MaskPair) @@ -608,9 +612,9 @@ // the same as either B or D). APInt NewMask = *ConstB & *ConstD; if (NewMask == *ConstB) - return LHS; + return LHS.getOrCreateValue(Builder); else if (NewMask == *ConstD) - return RHS; + return RHS.getOrCreateValue(Builder); } if (Mask & AMask_NotAllOnes) { @@ -620,9 +624,9 @@ // the same as either B or D). 
APInt NewMask = *ConstB | *ConstD; if (NewMask == *ConstB) - return LHS; + return LHS.getOrCreateValue(Builder); else if (NewMask == *ConstD) - return RHS; + return RHS.getOrCreateValue(Builder); } if (Mask & (BMask_Mixed | BMask_NotMixed)) { @@ -658,7 +662,8 @@ const APInt ConstE = PredR != CC ? *ConstD ^ *OldConstE : *OldConstE; if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue()) - return IsNot ? nullptr : ConstantInt::get(LHS->getType(), !IsAnd); + return IsNot ? nullptr + : ConstantInt::get(LHS.getType(), !IsAnd); if (IsNot && !ConstB->isSubsetOf(*ConstD) && !ConstD->isSubsetOf(*ConstB)) return nullptr; @@ -688,33 +693,35 @@ /// Example: (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n /// If \p Inverted is true then the check is for the inverted range, e.g. /// (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n -Value *InstCombinerImpl::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, - bool Inverted) { +Value * +InstCombinerImpl::simplifyRangeCheck(InstCombinerImpl::ICmpComponents Cmp0, + InstCombinerImpl::ICmpComponents Cmp1, + bool Inverted) { // Check the lower range comparison, e.g. x >= 0 // InstCombine already ensured that if there is a constant it's on the RHS. - ConstantInt *RangeStart = dyn_cast(Cmp0->getOperand(1)); + ConstantInt *RangeStart = dyn_cast(Cmp0.getOperand(1)); if (!RangeStart) return nullptr; - ICmpInst::Predicate Pred0 = (Inverted ? Cmp0->getInversePredicate() : - Cmp0->getPredicate()); + ICmpInst::Predicate Pred0 = + (Inverted ? Cmp0.getInversePredicate() : Cmp0.getPredicate()); // Accept x > -1 or x >= 0 (after potentially inverting the predicate). if (!((Pred0 == ICmpInst::ICMP_SGT && RangeStart->isMinusOne()) || (Pred0 == ICmpInst::ICMP_SGE && RangeStart->isZero()))) return nullptr; - ICmpInst::Predicate Pred1 = (Inverted ? Cmp1->getInversePredicate() : - Cmp1->getPredicate()); + ICmpInst::Predicate Pred1 = + (Inverted ? 
Cmp1.getInversePredicate() : Cmp1.getPredicate()); - Value *Input = Cmp0->getOperand(0); + Value *Input = Cmp0.getOperand(0); Value *RangeEnd; - if (Cmp1->getOperand(0) == Input) { + if (Cmp1.getOperand(0) == Input) { // For the upper range compare we have: icmp x, n - RangeEnd = Cmp1->getOperand(1); - } else if (Cmp1->getOperand(1) == Input) { + RangeEnd = Cmp1.getOperand(1); + } else if (Cmp1.getOperand(1) == Input) { // For the upper range compare we have: icmp n, x - RangeEnd = Cmp1->getOperand(0); + RangeEnd = Cmp1.getOperand(0); Pred1 = ICmpInst::getSwappedPredicate(Pred1); } else { return nullptr; @@ -729,7 +736,7 @@ } // This simplification is only valid if the upper range is not negative. - KnownBits Known = computeKnownBits(RangeEnd, /*Depth=*/0, Cmp1); + KnownBits Known = computeKnownBits(RangeEnd, /*Depth=*/0, Cmp1.OrigCmp_); if (!Known.isNonNegative()) return nullptr; @@ -741,22 +748,20 @@ // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) -Value *InstCombinerImpl::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, - ICmpInst *RHS, - Instruction *CxtI, - bool IsAnd, - bool IsLogical) { +Value *InstCombinerImpl::foldAndOrOfICmpsOfAndWithPow2( + InstCombinerImpl::ICmpComponents LHS, InstCombinerImpl::ICmpComponents RHS, + Instruction *CxtI, bool IsAnd, bool IsLogical) { CmpInst::Predicate Pred = IsAnd ? 
CmpInst::ICMP_NE : CmpInst::ICMP_EQ; - if (LHS->getPredicate() != Pred || RHS->getPredicate() != Pred) + if (LHS.getPredicate() != Pred || RHS.getPredicate() != Pred) return nullptr; - if (!match(LHS->getOperand(1), m_Zero()) || - !match(RHS->getOperand(1), m_Zero())) + if (!match(LHS.getOperand(1), m_Zero()) || + !match(RHS.getOperand(1), m_Zero())) return nullptr; Value *L1, *L2, *R1, *R2; - if (match(LHS->getOperand(0), m_And(m_Value(L1), m_Value(L2))) && - match(RHS->getOperand(0), m_And(m_Value(R1), m_Value(R2)))) { + if (match(LHS.getOperand(0), m_And(m_Value(L1), m_Value(L2))) && + match(RHS.getOperand(0), m_And(m_Value(R1), m_Value(R2)))) { if (L1 == R2 || L2 == R2) std::swap(R1, R2); if (L2 == R1) @@ -809,19 +814,25 @@ /// masked bits are zero. /// So this should be transformed to: /// %r = icmp ult i32 %arg, 128 -static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1, +static Value *foldSignedTruncationCheck(InstCombinerImpl::ICmpComponents ICmp0, + InstCombinerImpl::ICmpComponents ICmp1, Instruction &CxtI, InstCombiner::BuilderTy &Builder) { assert(CxtI.getOpcode() == Instruction::And); // Match icmp ult (add %arg, C01), C1 (C1 == C01 << 1; powers of two) - auto tryToMatchSignedTruncationCheck = [](ICmpInst *ICmp, Value *&X, - APInt &SignBitMask) -> bool { + + auto tryToMatchSignedTruncationCheck = + [](InstCombinerImpl::ICmpComponents ICmp, Value *&X, + APInt &SignBitMask) -> bool { CmpInst::Predicate Pred; const APInt *I01, *I1; // powers of two; I1 == I01 << 1 - if (!(match(ICmp, - m_ICmp(Pred, m_Add(m_Value(X), m_Power2(I01)), m_Power2(I1))) && - Pred == ICmpInst::ICMP_ULT && I1->ugt(*I01) && I01->shl(1) == *I1)) + + if (!match(ICmp.getOperand(0), m_Add(m_Value(X), m_Power2(I01))) || + !match(ICmp.getOperand(1), m_Power2(I1))) + return false; + Pred = ICmp.getPredicate(); + if (Pred != ICmpInst::ICMP_ULT || I1->ule(*I01) || I01->shl(1) != *I1) return false; // Which bit is the new sign bit as per the 'signed truncation' pattern? 
SignBitMask = *I01; @@ -832,7 +843,7 @@ // We need to match this first, else we will mismatch commutative cases. Value *X1; APInt HighestBit; - ICmpInst *OtherICmp; + InstCombinerImpl::ICmpComponents OtherICmp; if (tryToMatchSignedTruncationCheck(ICmp1, X1, HighestBit)) OtherICmp = ICmp0; else if (tryToMatchSignedTruncationCheck(ICmp0, X1, HighestBit)) @@ -843,19 +854,20 @@ assert(HighestBit.isPowerOf2() && "expected to be power of two (non-zero)"); // Try to match/decompose into: icmp eq (X & Mask), 0 - auto tryToDecompose = [](ICmpInst *ICmp, Value *&X, + auto tryToDecompose = [](InstCombinerImpl::ICmpComponents ICmp, Value *&X, APInt &UnsetBitsMask) -> bool { - CmpInst::Predicate Pred = ICmp->getPredicate(); + CmpInst::Predicate Pred = ICmp.getPredicate(); // Can it be decomposed into icmp eq (X & Mask), 0 ? - if (llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1), - Pred, X, UnsetBitsMask, + if (llvm::decomposeBitTestICmp(ICmp.getOperand(0), ICmp.getOperand(1), Pred, + X, UnsetBitsMask, /*LookThroughTrunc=*/false) && Pred == ICmpInst::ICMP_EQ) return true; // Is it icmp eq (X & Mask), 0 already? const APInt *Mask; - if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) && - Pred == ICmpInst::ICMP_EQ) { + + if (match(ICmp.getOperand(0), m_And(m_Value(X), m_APInt(Mask))) && + match(ICmp.getOperand(1), m_Zero()) && Pred == ICmpInst::ICMP_EQ) { UnsetBitsMask = *Mask; return true; } @@ -906,16 +918,23 @@ /// Fold (icmp eq ctpop(X) 1) | (icmp eq X 0) into (icmp ult ctpop(X) 2) and /// fold (icmp ne ctpop(X) 1) & (icmp ne X 0) into (icmp ugt ctpop(X) 1). /// Also used for logical and/or, must be poison safe. 
-static Value *foldIsPowerOf2OrZero(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd, +static Value *foldIsPowerOf2OrZero(InstCombinerImpl::ICmpComponents Cmp0, + InstCombinerImpl::ICmpComponents Cmp1, + bool IsAnd, InstCombiner::BuilderTy &Builder) { CmpInst::Predicate Pred0, Pred1; Value *X; - if (!match(Cmp0, m_ICmp(Pred0, m_Intrinsic(m_Value(X)), - m_SpecificInt(1))) || - !match(Cmp1, m_ICmp(Pred1, m_Specific(X), m_ZeroInt()))) + + Pred0 = Cmp0.getPredicate(); + if (!match(Cmp0.getOperand(0), m_Intrinsic(m_Value(X))) || + !match(Cmp0.getOperand(1), m_SpecificInt(1))) + return nullptr; + + Pred1 = Cmp1.getPredicate(); + if (Cmp1.getOperand(0) != X || !match(Cmp1.getOperand(1), m_ZeroInt())) return nullptr; - Value *CtPop = Cmp0->getOperand(0); + Value *CtPop = Cmp0.getOperand(0); if (IsAnd && Pred0 == ICmpInst::ICMP_NE && Pred1 == ICmpInst::ICMP_NE) return Builder.CreateICmpUGT(CtPop, ConstantInt::get(CtPop->getType(), 1)); if (!IsAnd && Pred0 == ICmpInst::ICMP_EQ && Pred1 == ICmpInst::ICMP_EQ) @@ -926,32 +945,40 @@ /// Reduce a pair of compares that check if a value has exactly 1 bit set. /// Also used for logical and/or, must be poison safe. -static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd, +static Value *foldIsPowerOf2(InstCombinerImpl::ICmpComponents Cmp0, + InstCombinerImpl::ICmpComponents Cmp1, + bool JoinedByAnd, InstCombiner::BuilderTy &Builder) { // Handle 'and' / 'or' commutation: make the equality check the first operand. 
- if (JoinedByAnd && Cmp1->getPredicate() == ICmpInst::ICMP_NE) + if (JoinedByAnd && Cmp1.getPredicate() == ICmpInst::ICMP_NE) std::swap(Cmp0, Cmp1); - else if (!JoinedByAnd && Cmp1->getPredicate() == ICmpInst::ICMP_EQ) + else if (!JoinedByAnd && Cmp1.getPredicate() == ICmpInst::ICMP_EQ) std::swap(Cmp0, Cmp1); // (X != 0) && (ctpop(X) u< 2) --> ctpop(X) == 1 CmpInst::Predicate Pred0, Pred1; Value *X; - if (JoinedByAnd && match(Cmp0, m_ICmp(Pred0, m_Value(X), m_ZeroInt())) && - match(Cmp1, m_ICmp(Pred1, m_Intrinsic(m_Specific(X)), - m_SpecificInt(2))) && + + Pred0 = Cmp0.getPredicate(); + X = Cmp0.getOperand(0); + Pred1 = Cmp1.getPredicate(); + if (JoinedByAnd && match(Cmp0.getOperand(1), m_ZeroInt()) && + match(Cmp1.getOperand(0), m_Intrinsic(m_Specific(X))) && + match(Cmp1.getOperand(1), m_SpecificInt(2)) && Pred0 == ICmpInst::ICMP_NE && Pred1 == ICmpInst::ICMP_ULT) { - Value *CtPop = Cmp1->getOperand(0); + Value *CtPop = Cmp1.getOperand(0); return Builder.CreateICmpEQ(CtPop, ConstantInt::get(CtPop->getType(), 1)); } + // (X == 0) || (ctpop(X) u> 1) --> ctpop(X) != 1 - if (!JoinedByAnd && match(Cmp0, m_ICmp(Pred0, m_Value(X), m_ZeroInt())) && - match(Cmp1, m_ICmp(Pred1, m_Intrinsic(m_Specific(X)), - m_SpecificInt(1))) && + if (!JoinedByAnd && match(Cmp0.getOperand(1), m_ZeroInt()) && + match(Cmp1.getOperand(0), m_Intrinsic(m_Specific(X))) && + match(Cmp1.getOperand(1), m_SpecificInt(1)) && Pred0 == ICmpInst::ICMP_EQ && Pred1 == ICmpInst::ICMP_UGT) { - Value *CtPop = Cmp1->getOperand(0); + Value *CtPop = Cmp1.getOperand(0); return Builder.CreateICmpNE(CtPop, ConstantInt::get(CtPop->getType(), 1)); } + return nullptr; } @@ -1026,14 +1053,15 @@ /// 2, an earlier optimization converts the expression into (icmp X s> -1). /// Parameter P supports masking using undef/poison in either scalar or vector /// values. 
-static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1, +static Value *foldPowerOf2AndShiftedMask(InstCombinerImpl::ICmpComponents Cmp0, + InstCombinerImpl::ICmpComponents Cmp1, bool JoinedByAnd, InstCombiner::BuilderTy &Builder) { if (!JoinedByAnd) return nullptr; Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; - ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate(), - CmpPred1 = Cmp1->getPredicate(); + ICmpInst::Predicate CmpPred0 = Cmp0.getPredicate(), + CmpPred1 = Cmp1.getPredicate(); // Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u< // 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X & // SignMask) == 0). @@ -1059,14 +1087,17 @@ /// Commuted variants are assumed to be handled by calling this function again /// with the parameters swapped. -static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp, - ICmpInst *UnsignedICmp, bool IsAnd, - const SimplifyQuery &Q, - InstCombiner::BuilderTy &Builder) { +static Value * +foldUnsignedUnderflowCheck(InstCombinerImpl::ICmpComponents ZeroICmp, + InstCombinerImpl::ICmpComponents UnsignedICmp, + bool IsAnd, const SimplifyQuery &Q, + InstCombiner::BuilderTy &Builder) { Value *ZeroCmpOp; ICmpInst::Predicate EqPred; - if (!match(ZeroICmp, m_ICmp(EqPred, m_Value(ZeroCmpOp), m_Zero())) || - !ICmpInst::isEquality(EqPred)) + + EqPred = ZeroICmp.getPredicate(); + ZeroCmpOp = ZeroICmp.getOperand(0); + if (!match(ZeroICmp.getOperand(1), m_Zero()) || !ICmpInst::isEquality(EqPred)) return nullptr; auto IsKnownNonZero = [&](Value *V) { @@ -1074,12 +1105,18 @@ }; ICmpInst::Predicate UnsignedPred; + Value *A = nullptr, *B; + + if (ZeroCmpOp == UnsignedICmp.getOperand(0)) { + A = UnsignedICmp.getOperand(1); + UnsignedPred = UnsignedICmp.getPredicate(); + } else if (ZeroCmpOp == UnsignedICmp.getOperand(1)) { + A = UnsignedICmp.getOperand(0); + UnsignedPred = UnsignedICmp.getSwappedPredicate(); + } - Value *A, *B; - if (match(UnsignedICmp, - 
m_c_ICmp(UnsignedPred, m_Specific(ZeroCmpOp), m_Value(A))) && - match(ZeroCmpOp, m_c_Add(m_Specific(A), m_Value(B))) && - (ZeroICmp->hasOneUse() || UnsignedICmp->hasOneUse())) { + if (A != nullptr && match(ZeroCmpOp, m_c_Add(m_Specific(A), m_Value(B))) && + (ZeroICmp.hasOneUse() || UnsignedICmp.hasOneUse())) { auto GetKnownNonZeroAndOther = [&](Value *&NonZero, Value *&Other) { if (!IsKnownNonZero(NonZero)) std::swap(NonZero, Other); @@ -1103,9 +1140,16 @@ if (!match(ZeroCmpOp, m_Sub(m_Value(Base), m_Value(Offset)))) return nullptr; - if (!match(UnsignedICmp, - m_c_ICmp(UnsignedPred, m_Specific(Base), m_Specific(Offset))) || - !ICmpInst::isUnsigned(UnsignedPred)) + if (UnsignedICmp.getOperand(0) == Base && + UnsignedICmp.getOperand(1) == Offset) + UnsignedPred = UnsignedICmp.getPredicate(); + else if (UnsignedICmp.getOperand(1) == Base && + UnsignedICmp.getOperand(0) == Offset) + UnsignedPred = UnsignedICmp.getSwappedPredicate(); + else + return nullptr; + + if (!ICmpInst::isUnsigned(UnsignedPred)) return nullptr; // Base >=/> Offset && (Base - Offset) != 0 <--> Base > Offset @@ -1173,19 +1217,20 @@ /// (icmp eq X0, Y0) & (icmp eq X1, Y1) -> icmp eq X01, Y01 /// (icmp ne X0, Y0) | (icmp ne X1, Y1) -> icmp ne X01, Y01 /// where X0, X1 and Y0, Y1 are adjacent parts extracted from an integer. -Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, +Value *InstCombinerImpl::foldEqOfParts(InstCombinerImpl::ICmpComponents Cmp0, + InstCombinerImpl::ICmpComponents Cmp1, bool IsAnd) { - if (!Cmp0->hasOneUse() || !Cmp1->hasOneUse()) + if (!Cmp0.hasOneUse() || !Cmp1.hasOneUse()) return nullptr; CmpInst::Predicate Pred = IsAnd ? 
CmpInst::ICMP_EQ : CmpInst::ICMP_NE; - if (Cmp0->getPredicate() != Pred || Cmp1->getPredicate() != Pred) + if (Cmp0.getPredicate() != Pred || Cmp1.getPredicate() != Pred) return nullptr; - std::optional L0 = matchIntPart(Cmp0->getOperand(0)); - std::optional R0 = matchIntPart(Cmp0->getOperand(1)); - std::optional L1 = matchIntPart(Cmp1->getOperand(0)); - std::optional R1 = matchIntPart(Cmp1->getOperand(1)); + std::optional L0 = matchIntPart(Cmp0.getOperand(0)); + std::optional R0 = matchIntPart(Cmp0.getOperand(1)); + std::optional L1 = matchIntPart(Cmp1.getOperand(0)); + std::optional R1 = matchIntPart(Cmp1.getOperand(1)); if (!L0 || !R0 || !L1 || !R1) return nullptr; @@ -1219,18 +1264,22 @@ /// Reduce logic-of-compares with equality to a constant by substituting a /// common operand with the constant. Callers are expected to call this with /// Cmp0/Cmp1 switched to handle logic op commutativity. -static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1, +static Value *foldAndOrOfICmpsWithConstEq(InstCombinerImpl::ICmpComponents Cmp0, + InstCombinerImpl::ICmpComponents Cmp1, bool IsAnd, bool IsLogical, InstCombiner::BuilderTy &Builder, const SimplifyQuery &Q) { // Match an equality compare with a non-poison constant as Cmp0. // Also, give up if the compare can be constant-folded to avoid looping. - ICmpInst::Predicate Pred0; - Value *X; + + ICmpInst::Predicate Pred0 = Cmp0.getPredicate(); + Value *X = Cmp0.getOperand(0); Constant *C; - if (!match(Cmp0, m_ICmp(Pred0, m_Value(X), m_Constant(C))) || + + if (!match(Cmp0.getOperand(1), m_Constant(C)) || !isGuaranteedNotToBeUndefOrPoison(C) || isa(X)) return nullptr; + if ((IsAnd && Pred0 != ICmpInst::ICMP_EQ) || (!IsAnd && Pred0 != ICmpInst::ICMP_NE)) return nullptr; @@ -1240,8 +1289,16 @@ // operand 0). 
Value *Y; ICmpInst::Predicate Pred1; - if (!match(Cmp1, m_c_ICmp(Pred1, m_Value(Y), m_Deferred(X)))) + + if (Cmp1.getOperand(1) == X) { + Y = Cmp1.getOperand(0); + Pred1 = Cmp1.getPredicate(); + } else if (Cmp1.getOperand(0) == X) { + Y = Cmp1.getOperand(1); + Pred1 = Cmp1.getSwappedPredicate(); + } else { return nullptr; + } // Replace variable with constant value equivalence to remove a variable use: // (X == C) && (Y Pred1 X) --> (X == C) && (Y Pred1 C) @@ -1252,14 +1309,15 @@ if (!SubstituteCmp) { // If we need to create a new instruction, require that the old compare can // be removed. - if (!Cmp1->hasOneUse()) + if (!Cmp1.hasOneUse()) return nullptr; SubstituteCmp = Builder.CreateICmp(Pred1, Y, C); } + Value *Cmp0V = Cmp0.getOrCreateValue(Builder); if (IsLogical) - return IsAnd ? Builder.CreateLogicalAnd(Cmp0, SubstituteCmp) - : Builder.CreateLogicalOr(Cmp0, SubstituteCmp); - return Builder.CreateBinOp(IsAnd ? Instruction::And : Instruction::Or, Cmp0, + return IsAnd ? Builder.CreateLogicalAnd(Cmp0V, SubstituteCmp) + : Builder.CreateLogicalOr(Cmp0V, SubstituteCmp); + return Builder.CreateBinOp(IsAnd ? Instruction::And : Instruction::Or, Cmp0V, SubstituteCmp); } @@ -1267,14 +1325,18 @@ /// or (icmp Pred1 V1, C1) | (icmp Pred2 V2, C2) /// into a single comparison using range-based reasoning. /// NOTE: This is also used for logical and/or, must be poison-safe! 
-Value *InstCombinerImpl::foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1, - ICmpInst *ICmp2, - bool IsAnd) { +Value *InstCombinerImpl::foldAndOrOfICmpsUsingRanges( + InstCombinerImpl::ICmpComponents ICmp1, + InstCombinerImpl::ICmpComponents ICmp2, bool IsAnd) { ICmpInst::Predicate Pred1, Pred2; Value *V1, *V2; const APInt *C1, *C2; - if (!match(ICmp1, m_ICmp(Pred1, m_Value(V1), m_APInt(C1))) || - !match(ICmp2, m_ICmp(Pred2, m_Value(V2), m_APInt(C2)))) + Pred1 = ICmp1.getPredicate(); + Pred2 = ICmp2.getPredicate(); + V1 = ICmp1.getOperand(0); + V2 = ICmp2.getOperand(0); + if (!match(ICmp1.getOperand(1), m_APInt(C1)) || + !match(ICmp2.getOperand(1), m_APInt(C2))) return nullptr; // Look through add of a constant offset on V1, V2, or both operands. This @@ -1305,7 +1367,7 @@ Value *NewV = V1; std::optional CR = CR1.exactUnionWith(CR2); if (!CR) { - if (!(ICmp1->hasOneUse() && ICmp2->hasOneUse()) || CR1.isWrappedSet() || + if (!(ICmp1.hasOneUse() && ICmp2.hasOneUse()) || CR1.isWrappedSet() || CR2.isWrappedSet()) return nullptr; @@ -1357,29 +1419,33 @@ /// and (fcmp ord x, 0), (fcmp u* x, inf) -> fcmp o* x, inf /// /// Clang emits this pattern for doing an isfinite check in __builtin_isnormal. 
-static Value *matchIsFiniteTest(InstCombiner::BuilderTy &Builder, FCmpInst *LHS, - FCmpInst *RHS) { - Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); - Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); - FCmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); +static Value *matchIsFiniteTest(InstCombiner::BuilderTy &Builder, + InstCombinerImpl::FCmpComponents LHS, + InstCombinerImpl::FCmpComponents RHS) { + Value *LHS0 = LHS.getOperand(0), *LHS1 = LHS.getOperand(1); + Value *RHS0 = RHS.getOperand(0), *RHS1 = RHS.getOperand(1); + FCmpInst::Predicate PredL = LHS.getPredicate(); + FCmpInst::Predicate PredR = RHS.getPredicate(); if (!matchIsNotNaN(PredL, LHS0, LHS1) || !matchUnorderedInfCompare(PredR, RHS0, RHS1)) return nullptr; IRBuilder<>::FastMathFlagGuard FMFG(Builder); - FastMathFlags FMF = LHS->getFastMathFlags(); - FMF &= RHS->getFastMathFlags(); + FastMathFlags FMF = LHS.OrigCmp_->getFastMathFlags(); + FMF &= RHS.OrigCmp_->getFastMathFlags(); Builder.setFastMathFlags(FMF); return Builder.CreateFCmp(FCmpInst::getOrderedPredicate(PredR), RHS0, RHS1); } -Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, - bool IsAnd, bool IsLogicalSelect) { - Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); - Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); - FCmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); +Value *InstCombinerImpl::foldLogicOfFCmps(FCmpComponents LHS, + FCmpComponents RHS, bool IsAnd, + bool IsLogicalSelect) { + Value *LHS0 = LHS.getOperand(0), *LHS1 = LHS.getOperand(1); + Value *RHS0 = RHS.getOperand(0), *RHS1 = RHS.getOperand(1); + FCmpInst::Predicate PredL = LHS.getPredicate(); + FCmpInst::Predicate PredR = RHS.getPredicate(); if (LHS0 == RHS1 && RHS0 == LHS1) { // Swap RHS operands to match LHS. @@ -1409,8 +1475,8 @@ // Intersect the fast math flags. // TODO: We can union the fast math flags unless this is a logical select. 
IRBuilder<>::FastMathFlagGuard FMFG(Builder); - FastMathFlags FMF = LHS->getFastMathFlags(); - FMF &= RHS->getFastMathFlags(); + FastMathFlags FMF = LHS.OrigCmp_->getFastMathFlags(); + FMF &= RHS.OrigCmp_->getFastMathFlags(); Builder.setFastMathFlags(FMF); return getFCmpValue(NewPred, LHS0, LHS1, Builder); @@ -1447,12 +1513,12 @@ // If we can represent a combined value test with one class call, we can // potentially eliminate 4-6 instructions. If we can represent a test with a // single fcmp with fneg and fabs, that's likely a better canonical form. - if (LHS->hasOneUse() && RHS->hasOneUse()) { + if (LHS.hasOneUse() && RHS.hasOneUse()) { auto [ClassValRHS, ClassMaskRHS] = - fcmpToClassTest(PredR, *RHS->getFunction(), RHS0, RHS1); + fcmpToClassTest(PredR, *RHS.OrigCmp_->getFunction(), RHS0, RHS1); if (ClassValRHS) { auto [ClassValLHS, ClassMaskLHS] = - fcmpToClassTest(PredL, *LHS->getFunction(), LHS0, LHS1); + fcmpToClassTest(PredL, *LHS.OrigCmp_->getFunction(), LHS0, LHS1); if (ClassValLHS == ClassValRHS) { unsigned CombinedMask = IsAnd ? (ClassMaskLHS & ClassMaskRHS) : (ClassMaskLHS | ClassMaskRHS); @@ -3007,23 +3073,23 @@ // (icmp eq X, C) | (icmp ult Other, (X - C)) -> (icmp ule Other, (X - (C + 1))) // (icmp ne X, C) & (icmp uge Other, (X - C)) -> (icmp ugt Other, (X - (C + 1))) -static Value *foldAndOrOfICmpEqConstantAndICmp(ICmpInst *LHS, ICmpInst *RHS, - bool IsAnd, bool IsLogical, - IRBuilderBase &Builder) { - Value *LHS0 = LHS->getOperand(0); - Value *RHS0 = RHS->getOperand(0); - Value *RHS1 = RHS->getOperand(1); +static Value *foldAndOrOfICmpEqConstantAndICmp( + InstCombinerImpl::ICmpComponents LHS, InstCombinerImpl::ICmpComponents RHS, + bool IsAnd, bool IsLogical, IRBuilderBase &Builder) { + Value *LHS0 = LHS.getOperand(0); + Value *RHS0 = RHS.getOperand(0); + Value *RHS1 = RHS.getOperand(1); ICmpInst::Predicate LPred = - IsAnd ? LHS->getInversePredicate() : LHS->getPredicate(); + IsAnd ? 
LHS.getInversePredicate() : LHS.getPredicate(); ICmpInst::Predicate RPred = - IsAnd ? RHS->getInversePredicate() : RHS->getPredicate(); + IsAnd ? RHS.getInversePredicate() : RHS.getPredicate(); const APInt *CInt; if (LPred != ICmpInst::ICMP_EQ || - !match(LHS->getOperand(1), m_APIntAllowUndef(CInt)) || + !match(LHS.getOperand(1), m_APIntAllowUndef(CInt)) || !LHS0->getType()->isIntOrIntVectorTy() || - !(LHS->hasOneUse() || RHS->hasOneUse())) + !(LHS.hasOneUse() || RHS.hasOneUse())) return nullptr; auto MatchRHSOp = [LHS0, CInt](const Value *RHSOp) { @@ -3052,9 +3118,9 @@ /// Fold (icmp)&(icmp) or (icmp)|(icmp) if possible. /// If IsLogical is true, then the and/or is in select form and the transform /// must be poison-safe. -Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, - Instruction &I, bool IsAnd, - bool IsLogical) { +Value *InstCombinerImpl::foldAndOrOfICmps(ICmpComponents LHS, + ICmpComponents RHS, Instruction &I, + bool IsAnd, bool IsLogical) { const SimplifyQuery Q = SQ.getWithInstruction(&I); // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) @@ -3063,9 +3129,9 @@ if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, &I, IsAnd, IsLogical)) return V; - ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); - Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); - Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1); + ICmpInst::Predicate PredL = LHS.getPredicate(), PredR = RHS.getPredicate(); + Value *LHS0 = LHS.getOperand(0), *RHS0 = RHS.getOperand(0); + Value *LHS1 = LHS.getOperand(1), *RHS1 = RHS.getOperand(1); const APInt *LHSC = nullptr, *RHSC = nullptr; match(LHS1, m_APInt(LHSC)); match(RHS1, m_APInt(RHSC)); @@ -3080,7 +3146,7 @@ if (LHS0 == RHS0 && LHS1 == RHS1) { unsigned Code = IsAnd ? 
getICmpCode(PredL) & getICmpCode(PredR) : getICmpCode(PredL) | getICmpCode(PredR); - bool IsSigned = LHS->isSigned() || RHS->isSigned(); + bool IsSigned = LHS.isSigned() || RHS.isSigned(); return getNewICmpValue(Code, IsSigned, LHS0, LHS1, Builder); } } @@ -3180,7 +3246,7 @@ // where CMAX is the all ones value for the truncated type, // iff the lower bits of C2 and CA are zero. if (PredL == (IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) && - PredL == PredR && LHS->hasOneUse() && RHS->hasOneUse()) { + PredL == PredR && LHS.hasOneUse() && RHS.hasOneUse()) { Value *V; const APInt *AndC, *SmallC = nullptr, *BigC = nullptr; @@ -3219,7 +3285,7 @@ bool TrueIfSignedL, TrueIfSignedR; if (isSignBitCheck(PredL, *LHSC, TrueIfSignedL) && isSignBitCheck(PredR, *RHSC, TrueIfSignedR) && - (RHS->hasOneUse() || LHS->hasOneUse())) { + (RHS.hasOneUse() || LHS.hasOneUse())) { Value *X, *Y; if (IsAnd) { if ((TrueIfSignedL && !TrueIfSignedR && diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -108,7 +108,6 @@ Instruction *visitUDiv(BinaryOperator &I); Instruction *visitSDiv(BinaryOperator &I); Instruction *visitFDiv(BinaryOperator &I); - Value *simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, bool Inverted); Instruction *visitAnd(BinaryOperator &I); Instruction *visitOr(BinaryOperator &I); bool sinkNotIntoLogicalOp(Instruction &I); @@ -219,6 +218,56 @@ bool fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF, const Instruction *CtxI) const; + // Wrapper around icmp/fcmp instructions to be used in various folds to + // avoid having to have an actual icmp/fcmp Value created. + // For use when the question of whether we want to create a new cmp is a + // question of whether it will get folded. 
+  template <typename CmpType> struct CmpComponents {
+    using PredType = typename CmpType::Predicate;
+    PredType Pred_{};            // Predicate of the described compare.
+    Value *Op0_ = nullptr;       // Operand 0 of the described compare.
+    Value *Op1_ = nullptr;       // Operand 1 of the described compare.
+    CmpType *OrigCmp_ = nullptr; // Instruction the components came from.
+    // Use count of the original instruction, not of the components.
+    bool hasOneUse() const { return OrigCmp_->hasOneUse(); }
+    bool isSigned() const { return CmpType::isSigned(Pred_); }
+    // Result type, taken from the original instruction.
+    Type *getType() const { return OrigCmp_->getType(); }
+    // Operand 0 or 1 of the described compare; any other Idx asserts.
+    Value *getOperand(unsigned Idx) const {
+      assert(Idx <= 1 && "Out of bounds operand");
+      return Idx == 0 ? Op0_ : Op1_;
+    }
+    // Predicate accessors; inverse/swapped forms are computed by CmpType.
+    PredType getPredicate() const { return Pred_; }
+
+    PredType getInversePredicate() const {
+      return CmpType::getInversePredicate(Pred_);
+    }
+
+    PredType getSwappedPredicate() const {
+      return CmpType::getSwappedPredicate(Pred_);
+    }
+    // Reuse OrigCmp_ if it still matches the components, else build a new cmp.
+    Value *getOrCreateValue(InstCombiner::BuilderTy &Builder) const {
+      if (Pred_ == OrigCmp_->getPredicate() &&
+          Op0_ == OrigCmp_->getOperand(0) && Op1_ == OrigCmp_->getOperand(1))
+        return OrigCmp_;
+      return Builder.CreateCmp(Pred_, Op0_, Op1_);
+    }
+    // Empty by default; a CmpType * converts implicitly via the ctor below.
+    CmpComponents() = default;
+    CmpComponents(CmpType *Cmp)
+        : Pred_(Cmp->getPredicate()), Op0_(Cmp->getOperand(0)),
+          Op1_(Cmp->getOperand(1)), OrigCmp_(Cmp) {}
+  };
+
+  using FCmpComponents = CmpComponents<FCmpInst>;
+  using ICmpComponents = CmpComponents<ICmpInst>;
+
+  Value *simplifyRangeCheck(ICmpComponents Cmp0, ICmpComponents Cmp1,
+                            bool Inverted);
+
 private:
   bool annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI);
   bool isDesirableIntType(unsigned BitWidth) const;
@@ -371,19 +420,19 @@
                                             const CastInst *CI2);
   Value *simplifyIntToPtrRoundTripCast(Value *Val);
 
-  Value *foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &I,
-                          bool IsAnd, bool IsLogical = false);
+  Value *foldAndOrOfICmps(ICmpComponents LHS, ICmpComponents RHS,
+                          Instruction &I, bool IsAnd, bool IsLogical = false);
   Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Xor);
 
-  Value *foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd);
+  Value *foldEqOfParts(ICmpComponents Cmp0, ICmpComponents Cmp1, bool IsAnd);
 
-  Value
*foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1, ICmpInst *ICmp2, + Value *foldAndOrOfICmpsUsingRanges(ICmpComponents ICmp1, ICmpComponents ICmp2, bool IsAnd); /// Optimize (fcmp)&(fcmp) or (fcmp)|(fcmp). /// NOTE: Unlike most of instcombine, this returns a Value which should /// already be inserted into the function. - Value *foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd, + Value *foldLogicOfFCmps(FCmpComponents LHS, FCmpComponents RHS, bool IsAnd, bool IsLogicalSelect = false); Instruction *foldLogicOfIsFPClass(BinaryOperator &Operator, Value *LHS, @@ -392,7 +441,7 @@ Instruction * canonicalizeConditionalNegationViaMathToSelect(BinaryOperator &i); - Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, + Value *foldAndOrOfICmpsOfAndWithPow2(ICmpComponents LHS, ICmpComponents RHS, Instruction *CxtI, bool IsAnd, bool IsLogical = false); Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D,