diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -925,6 +925,15 @@
   /// Rotate right by rotateAmt.
   APInt rotr(const APInt &rotateAmt) const;
 
+  /// Returns true if *this is the bitwise complement of RHS (*this == ~RHS).
+  bool isInvertOf(const APInt &RHS) const {
+    if (isSingleWord()) {
+      uint64_t Mask = (WORDTYPE_MAX >> (APINT_BITS_PER_WORD - BitWidth));
+      return U.VAL == (~RHS.U.VAL & Mask);
+    }
+    return isInvertOfSlowCase(RHS);
+  }
+
   /// Concatenate the bits from "NewLSB" onto the bottom of *this.  This is
   /// equivalent to:
   ///   (this->zext(NewWidth) << NewLSB.getBitWidth()) | NewLSB.zext(NewWidth)
@@ -1957,6 +1966,9 @@
   /// out-of-line slow case for isInvertOf
   void assignSlowCase(const APInt &RHS);
 
+  /// out-of-line slow case for isInvertOf
+  bool isInvertOfSlowCase(const APInt &RHS) const LLVM_READONLY;
+
   /// out-of-line slow case for operator==
   bool equalSlowCase(const APInt &RHS) const LLVM_READONLY;
 
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -49,7 +49,7 @@
   /// Returns true if we know the value of all bits.
   bool isConstant() const {
     assert(!hasConflict() && "KnownBits conflict!");
-    return Zero.countPopulation() + One.countPopulation() == getBitWidth();
+    return Zero.isInvertOf(One);
   }
 
   /// Returns the value when all bits have a known value. This just returns One
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -2395,7 +2395,7 @@
   const APInt *C1, *C2;
   if (match(Op0, m_And(m_Value(A), m_APInt(C1))) &&
       match(Op1, m_And(m_Value(B), m_APInt(C2)))) {
-    if (*C1 == ~*C2) {
+    if (C1->isInvertOf(*C2)) {
       // (A & C1)|(B & C2)
       // If we have: ((V + N) & C1) | (V & C2)
       // .. and C2 = ~C1 and C2 is 0+1+ and (N & C2) == 0
@@ -4207,13 +4207,13 @@
   // (X & Y) == 0 ? X & ~Y : X  --> X
   // (X & Y) != 0 ? X & ~Y : X  --> X & ~Y
   if (FalseVal == X && match(TrueVal, m_And(m_Specific(X), m_APInt(C))) &&
-      *Y == ~*C)
+      Y->isInvertOf(*C))
     return TrueWhenUnset ? FalseVal : TrueVal;
 
   // (X & Y) == 0 ? X : X & ~Y  --> X & ~Y
   // (X & Y) != 0 ? X : X & ~Y  --> X
   if (TrueVal == X && match(FalseVal, m_And(m_Specific(X), m_APInt(C))) &&
-      *Y == ~*C)
+      Y->isInvertOf(*C))
     return TrueWhenUnset ? FalseVal : TrueVal;
 
   if (Y->isPowerOf2()) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10670,7 +10670,7 @@
       N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) {
     auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1));
     auto *NotC = dyn_cast<ConstantSDNode>(Cond1);
-    if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
+    if (C && NotC && C->getAPIntValue().isInvertOf(NotC->getAPIntValue())) {
       // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
       // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
       //
@@ -11246,7 +11246,7 @@
         // canonicalization.
         // x >= ~C ? x+C : ~0 --> uaddsat x, C
         auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
-          return Cond->getAPIntValue() == ~Op->getAPIntValue();
+          return Cond->getAPIntValue().isInvertOf(Op->getAPIntValue());
         };
         if (SatCC == ISD::SETULE &&
             ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUADDSAT))
@@ -24637,7 +24637,7 @@
   // Fold select_cc setgt X, -1, C, ~C -> xor (ashr X, BW-1), C
   // Fold select_cc setlt X, 0, C, ~C -> xor (ashr X, BW-1), ~C
   if (!NotExtCompare && N1C && N2C && N3C &&
-      N2C->getAPIntValue() == ~N3C->getAPIntValue() &&
+      N2C->getAPIntValue().isInvertOf(N3C->getAPIntValue()) &&
      ((N1C->isAllOnes() && CC == ISD::SETGT) ||
       (N1C->isZero() && CC == ISD::SETLT)) &&
      !TLI.shouldAvoidTransformToShift(VT, CmpOpVT.getScalarSizeInBits() - 1)) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1372,7 +1372,7 @@
     // the xor. For example, for a 32-bit X:
     // and (xor (srl X, 31), -1), 1 --> xor (srl X, 31), 1
     if (isBitwiseNot(Op0) && Op0.hasOneUse() &&
-        LHSKnown.One == ~RHSC->getAPIntValue()) {
+        LHSKnown.One.isInvertOf(RHSC->getAPIntValue())) {
       SDValue Xor = TLO.DAG.getNode(ISD::XOR, dl, VT, Op0.getOperand(0), Op1);
       return TLO.CombineTo(Op, Xor);
     }
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -266,6 +266,20 @@
   return clearUnusedBits();
 }
 
+bool APInt::isInvertOfSlowCase(const APInt &RHS) const {
+  WordType *lhs = U.pVal, *rhs = RHS.U.pVal;
+  // Number of bits actually used in the topmost word.
+  unsigned WordBits = ((BitWidth - 1) % APINT_BITS_PER_WORD) + 1;
+  size_t e = getNumWords() - 1;
+  // Every word below the topmost one must be an exact bitwise complement.
+  for (size_t i = 0; i != e; ++i)
+    if (lhs[i] != ~rhs[i])
+      return false;
+  // The topmost word may be partial; compare it under a mask of its used bits.
+  uint64_t Mask = WORDTYPE_MAX >> (APINT_BITS_PER_WORD - WordBits);
+  return lhs[e] == (~rhs[e] & Mask);
+}
+
 bool APInt::equalSlowCase(const APInt &RHS) const {
   return std::equal(U.pVal, U.pVal + getNumWords(), RHS.U.pVal);
 }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -2923,7 +2923,7 @@
   if (And0.hasOneUse() && And1.hasOneUse() &&
       isOpcWithIntImmediate(And0.getNode(), ISD::AND, Mask0Imm) &&
       isOpcWithIntImmediate(And1.getNode(), ISD::AND, Mask1Imm) &&
-      APInt(BitWidth, Mask0Imm) == ~APInt(BitWidth, Mask1Imm) &&
+      APInt(BitWidth, Mask0Imm).isInvertOf(APInt(BitWidth, Mask1Imm)) &&
       (isShiftedMask(Mask0Imm, VT) || isShiftedMask(Mask1Imm, VT))) {
 
     // ORR is commutative, so canonicalize to the form 'or (and X, Mask0Imm),
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -8725,7 +8725,7 @@
       // If both TVal and FVal are constants, see if FVal is the
       // inverse/negation/increment of TVal and generate a CSINV/CSNEG/CSINC
       // instead of a CSEL in that case.
-      if (TrueVal == ~FalseVal) {
+      if (CTVal->getAPIntValue().isInvertOf(CFVal->getAPIntValue())) {
        Opcode = AArch64ISD::CSINV;
      } else if (FalseVal > std::numeric_limits<int64_t>::min() &&
                 TrueVal == -FalseVal) {
diff --git a/llvm/lib/Target/Mips/MipsSEISelLowering.cpp b/llvm/lib/Target/Mips/MipsSEISelLowering.cpp
--- a/llvm/lib/Target/Mips/MipsSEISelLowering.cpp
+++ b/llvm/lib/Target/Mips/MipsSEISelLowering.cpp
@@ -621,10 +621,12 @@
       IfSet = Op0Op1;
 
     if (isVSplat(Op1Op0, InvMask, IsLittleEndian) &&
-        Mask.getBitWidth() == InvMask.getBitWidth() && Mask == ~InvMask)
+        Mask.getBitWidth() == InvMask.getBitWidth() &&
+        Mask.isInvertOf(InvMask))
       IfClr = Op1Op1;
     else if (isVSplat(Op1Op1, InvMask, IsLittleEndian) &&
-             Mask.getBitWidth() == InvMask.getBitWidth() && Mask == ~InvMask)
+             Mask.getBitWidth() == InvMask.getBitWidth() &&
+             Mask.isInvertOf(InvMask))
       IfClr = Op1Op0;
 
     IsConstantMask = true;
@@ -638,10 +640,12 @@
       IfSet = Op0Op0;
 
     if (isVSplat(Op1Op0, InvMask, IsLittleEndian) &&
-        Mask.getBitWidth() == InvMask.getBitWidth() && Mask == ~InvMask)
+        Mask.getBitWidth() == InvMask.getBitWidth() &&
+        Mask.isInvertOf(InvMask))
       IfClr = Op1Op1;
     else if (isVSplat(Op1Op1, InvMask, IsLittleEndian) &&
-             Mask.getBitWidth() == InvMask.getBitWidth() && Mask == ~InvMask)
+             Mask.getBitWidth() == InvMask.getBitWidth() &&
+             Mask.isInvertOf(InvMask))
       IfClr = Op1Op0;
 
     IsConstantMask = true;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -773,7 +773,7 @@
   if (match(X, m_Xor(m_Value(Y), m_APInt(C1)))) {
     // X = XOR(Y, C1), Y = OR(Z, C2), C2 = NOT(C1) ==> X == NOT(AND(Z, C1))
     // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, AND(Z, C1))
-    if (match(Y, m_Or(m_Value(Z), m_APInt(C2))) && (*C2 == ~(*C1))) {
+    if (match(Y, m_Or(m_Value(Z), m_APInt(C2))) && C2->isInvertOf(*C1)) {
       Value *NewAnd = Builder.CreateAnd(Z, *C1);
       return Builder.CreateSub(RHS, NewAnd, "sub");
     } else if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && (*C1 == *C2)) {
@@ -1123,7 +1123,7 @@
   // add (umin X, ~C), C --> uaddsat X, C
   const APInt *C, *NotC;
   if (match(&I, m_Add(m_UMin(m_Value(X), m_APInt(NotC)), m_APInt(C))) &&
-      *C == ~*NotC)
+      C->isInvertOf(*NotC))
     return CallInst::Create(getUAddSat(), { X, ConstantInt::get(Ty, *C) });
 
   return nullptr;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2891,7 +2891,7 @@
   const APInt *C0, *C1;
   if (match(C, m_APInt(C0)) && match(D, m_APInt(C1))) {
     Value *X;
-    if (*C0 == ~*C1) {
+    if (C0->isInvertOf(*C1)) {
       // ((X | B) & MaskC) | (B & ~MaskC) -> (X & MaskC) | B
       if (match(A, m_c_Or(m_Value(X), m_Specific(B))))
         return BinaryOperator::CreateOr(Builder.CreateAnd(X, *C0), B);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1661,7 +1661,7 @@
   // Mask constant magic can eliminate an 'xor' with unsigned compares.
   if (Pred == ICmpInst::ICMP_UGT) {
     // (xor X, ~C) >u C --> X <u ~C (when C+1 is a power of 2)
-    if (*XorC == ~C && (C + 1).isPowerOf2())
+    if (XorC->isInvertOf(C) && (C + 1).isPowerOf2())
       return new ICmpInst(ICmpInst::ICMP_ULT, X, Y);
     // (xor X, C) >u C --> X >u C (when C+1 is a power of 2)
     if (*XorC == C && (C + 1).isPowerOf2())
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -700,7 +700,8 @@
 
   // Cond ? (X & ~C) : (X | C) --> (X & ~C) | (Cond ? 0 : C)
   if (match(T, m_And(m_Value(X), m_APInt(NotC))) &&
-      match(F, m_OneUse(m_Or(m_Specific(X), m_APInt(C)))) && *NotC == ~(*C)) {
+      match(F, m_OneUse(m_Or(m_Specific(X), m_APInt(C)))) &&
+      NotC->isInvertOf(*C)) {
     Constant *Zero = ConstantInt::getNullValue(Ty);
     Constant *OrC = ConstantInt::get(Ty, *C);
     Value *NewSel = Builder.CreateSelect(Cond, Zero, OrC, "masksel", &Sel);
@@ -709,7 +710,8 @@
 
   // Cond ? (X | C) : (X & ~C) --> (X & ~C) | (Cond ? C : 0)
   if (match(F, m_And(m_Value(X), m_APInt(NotC))) &&
-      match(T, m_OneUse(m_Or(m_Specific(X), m_APInt(C)))) && *NotC == ~(*C)) {
+      match(T, m_OneUse(m_Or(m_Specific(X), m_APInt(C)))) &&
+      NotC->isInvertOf(*C)) {
     Constant *Zero = ConstantInt::getNullValue(Ty);
     Constant *OrC = ConstantInt::get(Ty, *C);
     Value *NewSel = Builder.CreateSelect(Cond, OrC, Zero, "masksel", &Sel);
@@ -842,7 +844,8 @@
   const APInt *C, *CmpC;
   if (Pred == ICmpInst::ICMP_ULT &&
       match(TVal, m_Add(m_Value(X), m_APInt(C))) && X == Cmp0 &&
-      match(FVal, m_AllOnes()) && match(Cmp1, m_APInt(CmpC)) && *CmpC == ~*C) {
+      match(FVal, m_AllOnes()) && match(Cmp1, m_APInt(CmpC)) &&
+      CmpC->isInvertOf(*C)) {
     // (X u< ~C) ? (X + C) : -1 --> uadd.sat(X, C)
     return Builder.CreateBinaryIntrinsic(
         Intrinsic::uadd_sat, X, ConstantInt::get(X->getType(), *C));
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -2607,6 +2607,19 @@
   EXPECT_EQ(0, neg_one.shl(128));
 }
 
+TEST(APIntTest, isInvertOf) {
+  // Exercise both the single-word (<= 64 bits) and multi-word paths,
+  // including widths that leave unused bits in the topmost word.
+  for (auto BitWidth : {8, 17, 32, 53, 64, 122, 128}) {
+    APInt in_1(BitWidth, 4);
+    APInt in_2(BitWidth, 53);
+    EXPECT_TRUE(in_1.isInvertOf(~in_1));
+    EXPECT_TRUE((~in_1).isInvertOf(in_1));
+    EXPECT_FALSE(in_1.isInvertOf(in_1));
+    EXPECT_FALSE(in_1.isInvertOf(in_2));
+  }
+}
+
 TEST(APIntTest, isSubsetOf) {
   APInt i32_1(32, 1);
   APInt i32_2(32, 2);
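
Note for reviewers: the subtle point in isInvertOf is the unused high bits of the topmost word, which a plain ~ on the raw words would flip. Below is a minimal standalone sketch (plain C++, no LLVM dependency) that mirrors the masking logic of isInvertOfSlowCase; the helper name isInvertOfModel and the 122-bit example are illustrative assumptions, not part of the patch.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical model of the patch's isInvertOf semantics for an N-bit value
// stored as little-endian 64-bit words.
static bool isInvertOfModel(const std::vector<uint64_t> &LHS,
                            const std::vector<uint64_t> &RHS,
                            unsigned BitWidth) {
  // All words below the topmost one must be exact bitwise complements.
  size_t LastWord = LHS.size() - 1;
  for (size_t I = 0; I != LastWord; ++I)
    if (LHS[I] != ~RHS[I])
      return false;
  // The topmost word may be only partially used; compare it under a mask of
  // its low ((BitWidth - 1) % 64) + 1 bits, as isInvertOfSlowCase does.
  unsigned TopBits = ((BitWidth - 1) % 64) + 1;
  uint64_t Mask = ~uint64_t(0) >> (64 - TopBits);
  return LHS[LastWord] == (~RHS[LastWord] & Mask);
}

int main() {
  // 122-bit value 4 and its 122-bit complement: inverts of each other even
  // though the 6 unused bits of the top word would differ after a plain ~.
  std::vector<uint64_t> A = {4, 0};
  std::vector<uint64_t> NotA = {~uint64_t(4), ~uint64_t(0) >> 6};
  assert(isInvertOfModel(A, NotA, 122));
  assert(isInvertOfModel(NotA, A, 122));
  assert(!isInvertOfModel(A, A, 122));
  return 0;
}

Dropping Mask makes the first assert fail, which is exactly the case the WordBits computation in the patch guards against.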