diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -124,6 +124,27 @@
   static ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred,
                                            const APInt &Other);
 
+  /// Return true iff CR1 ult CR2 is equivalent to CR1 slt CR2.
+  /// Does not depend on strictness/direction of the predicate.
+  static bool
+  areInsensitiveToSignednessOfICmpPredicate(const ConstantRange &CR1,
+                                            const ConstantRange &CR2);
+
+  /// Return true iff CR1 ult CR2 is equivalent to CR1 sgt CR2.
+  /// Does not depend on strictness/direction of the predicate.
+  static bool
+  areInsensitiveToSignednessOfSwappedICmpPredicate(const ConstantRange &CR1,
+                                                   const ConstantRange &CR2);
+
+  /// If the comparison between constant ranges this and Other
+  /// is insensitive to the signedness of the comparison predicate,
+  /// return a predicate equivalent to \p Pred, with flipped signedness
+  /// (i.e. unsigned instead of signed or vice versa), and maybe swapped.
+  /// Otherwise, return CmpInst::BAD_ICMP_PREDICATE.
+  CmpInst::Predicate
+  getEquivalentPredWithFlippedSignedness(CmpInst::Predicate Pred,
+                                         const ConstantRange &Other) const;
+
   /// Produce the largest range containing all X such that "X BinOp Y" is
   /// guaranteed not to wrap (overflow) for *all* Y in Other. However, there may
   /// be *some* Y in Other for which additional X not contained in the result
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -10645,6 +10645,22 @@
   if (Depth > MaxSCEVOperationsImplicationDepth)
     return false;
 
+  // We want signed comparisons, so if this is an unsigned one, see if it can
+  // be replaced with an equivalent signed one. Note that Pred describes both
+  // the queried fact (LHS Pred RHS) and the known fact (FoundLHS Pred
+  // FoundRHS), so the replacement must be valid for both pairs of operands.
+  // If it is not, the non-negativity based reduction below may still apply.
+  if (ICmpInst::isUnsigned(Pred)) {
+    ICmpInst::Predicate SignedPred =
+        getSignedRange(LHS).getEquivalentPredWithFlippedSignedness(
+            Pred, getSignedRange(RHS));
+    if (SignedPred != CmpInst::Predicate::BAD_ICMP_PREDICATE &&
+        SignedPred ==
+            getSignedRange(FoundLHS).getEquivalentPredWithFlippedSignedness(
+                Pred, getSignedRange(FoundRHS)))
+      Pred = SignedPred;
+  }
+
   // We only want to work with GT comparison so far.
   if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
     Pred = CmpInst::getSwappedPredicate(Pred);
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -147,6 +147,69 @@
   return makeAllowedICmpRegion(Pred, C);
 }
 
+bool ConstantRange::areInsensitiveToSignednessOfICmpPredicate(
+    const ConstantRange &CR1, const ConstantRange &CR2) {
+  // Vacuously true: there is no pair of values to compare.
+  if (CR1.isEmptySet() || CR2.isEmptySet())
+    return true;
+
+  // x ult y and x slt y agree iff x and y have the same sign bit. A
+  // non-empty range that is full, wrapped or sign-wrapped contains values
+  // of both signs, so the predicates disagree for some pair of values.
+  if (CR1.isFullSet() || CR2.isFullSet())
+    return false;
+  if (CR1.isWrappedSet() || CR2.isWrappedSet())
+    return false;
+  if (CR1.isSignWrappedSet() || CR2.isSignWrappedSet())
+    return false;
+
+  // Each range now lies entirely within [0, SMAX] or [SMIN, UMAX], so the
+  // predicates agree iff both ranges lie in the same half.
+  return (CR1.isAllNonNegative() && CR2.isAllNonNegative()) ||
+         (CR1.isAllNegative() && CR2.isAllNegative());
+}
+
+bool ConstantRange::areInsensitiveToSignednessOfSwappedICmpPredicate(
+    const ConstantRange &CR1, const ConstantRange &CR2) {
+  // Vacuously true: there is no pair of values to compare.
+  if (CR1.isEmptySet() || CR2.isEmptySet())
+    return true;
+  if (CR1.isFullSet() || CR2.isFullSet())
+    return false;
+
+  // x ult y and x sgt y agree iff x == y or x and y have different sign
+  // bits. For each sign, a range's smallest and largest values of that sign
+  // are among its unsigned/signed extrema, so checking every pair of
+  // extrema is sufficient.
+  for (const APInt &N1 : {CR1.getUnsignedMin(), CR1.getUnsignedMax(),
+                          CR1.getSignedMin(), CR1.getSignedMax()})
+    for (const APInt &N2 : {CR2.getUnsignedMin(), CR2.getUnsignedMax(),
+                            CR2.getSignedMin(), CR2.getSignedMax()})
+      if (N1.ult(N2) != N1.sgt(N2))
+        return false;
+
+  return true;
+}
+
+CmpInst::Predicate ConstantRange::getEquivalentPredWithFlippedSignedness(
+    CmpInst::Predicate Pred, const ConstantRange &Other) const {
+  assert(CmpInst::isIntPredicate(Pred) && CmpInst::isRelational(Pred) &&
+         "Only for relational integer predicates!");
+
+  CmpInst::Predicate FlippedSignednessPred =
+      CmpInst::getFlippedSignednessPredicate(Pred);
+
+  // Prefer keeping the direction of the predicate (e.g. slt -> ult).
+  if (areInsensitiveToSignednessOfICmpPredicate(*this, Other))
+    return FlippedSignednessPred;
+
+  // Otherwise the direction may have to be swapped too (e.g. slt -> ugt).
+  if (areInsensitiveToSignednessOfSwappedICmpPredicate(*this, Other))
+    return CmpInst::getSwappedPredicate(FlippedSignednessPred);
+
+  return CmpInst::Predicate::BAD_ICMP_PREDICATE;
+}
+
 bool ConstantRange::getEquivalentICmp(CmpInst::Predicate &Pred,
                                       APInt &RHS) const {
   bool Success = false;
diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
--- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
+++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
@@ -66,6 +66,7 @@
 STATISTIC(NumAShrs, "Number of ashr converted to lshr");
 STATISTIC(NumSRems, "Number of srem converted to urem");
 STATISTIC(NumSExt, "Number of sext converted to zext");
+STATISTIC(NumSICmps, "Number of signed icmp preds simplified to unsigned");
 STATISTIC(NumAnd, "Number of ands removed");
 STATISTIC(NumNW, "Number of no-wrap deductions");
 STATISTIC(NumNSW, "Number of no-signed-wrap deductions");
@@ -294,11 +295,38 @@
   return true;
 }
 
+static bool processICmp(ICmpInst *Cmp, LazyValueInfo *LVI) {
+  // Only scalar integer comparisons are handled: LazyValueInfo computes
+  // constant ranges only for integer-typed values, and isIntegerTy() rejects
+  // both pointer and vector operands.
+  if (!Cmp->getOperand(0)->getType()->isIntegerTy())
+    return false;
+
+  if (!Cmp->isSigned())
+    return false;
+
+  ICmpInst::Predicate UnsignedPred =
+      LVI->getConstantRange(Cmp->getOperand(0), Cmp)
+          .getEquivalentPredWithFlippedSignedness(
+              Cmp->getPredicate(),
+              LVI->getConstantRange(Cmp->getOperand(1), Cmp));
+
+  if (UnsignedPred == ICmpInst::Predicate::BAD_ICMP_PREDICATE)
+    return false;
+
+  // Operand order is unchanged (any swap is folded into UnsignedPred), so
+  // the comparison can simply be rewritten in place.
+  ++NumSICmps;
+  Cmp->setPredicate(UnsignedPred);
+
+  return true;
+}
+
 /// See if LazyValueInfo's ability to exploit edge conditions or range
 /// information is sufficient to prove this comparison. Even for local
 /// conditions, this can sometimes prove conditions instcombine can't by
 /// exploiting range information.
-static bool processCmp(CmpInst *Cmp, LazyValueInfo *LVI) {
+static bool constantFoldNonLocalCmp(CmpInst *Cmp, LazyValueInfo *LVI) {
   Value *Op0 = Cmp->getOperand(0);
   auto *C = dyn_cast<Constant>(Cmp->getOperand(1));
   if (!C)
@@ -317,6 +345,17 @@
   return true;
 }
 
+static bool processCmp(CmpInst *Cmp, LazyValueInfo *LVI) {
+  if (constantFoldNonLocalCmp(Cmp, LVI))
+    return true;
+
+  // Otherwise, try to rewrite a signed integer predicate as an unsigned one.
+  if (auto *ICmp = dyn_cast<ICmpInst>(Cmp))
+    return processICmp(ICmp, LVI);
+
+  return false;
+}
+
 /// Simplify a switch instruction by removing cases which can never fire. If the
 /// uselessness of a case could be determined locally then constant propagation
 /// would already have figured it out. Instead, walk the predecessors and
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/icmp-signed-ptr.ll b/llvm/test/Transforms/CorrelatedValuePropagation/icmp-signed-ptr.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/icmp-signed-ptr.ll
@@ -0,0 +1,12 @@
+; RUN: opt < %s -correlated-propagation -S | FileCheck %s
+
+; LazyValueInfo cannot compute constant ranges for pointers, so signed pointer
+; comparisons must be left untouched (and must not crash).
+define i1 @icmp_slt_ptr(i8* %a, i8* %b) {
+; CHECK-LABEL: @icmp_slt_ptr(
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i8* [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %c = icmp slt i8* %a, %b
+  ret i1 %c
+}
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
--- a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
@@ -64,7 +64,7 @@
 ; CHECK: if.then:
 ; CHECK-NEXT: ret i32 1
 ; CHECK: if.end:
-; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i32 [[C]], 3
+; CHECK-NEXT: [[CMP1:%.*]] = icmp ult i32 [[C]], 3
 ; CHECK-NEXT: br i1 [[CMP1]], label [[IF_THEN2:%.*]], label [[IF_END8:%.*]]
 ; CHECK: if.then2:
 ; CHECK-NEXT: br i1 true, label [[IF_THEN4:%.*]], label [[IF_END6:%.*]]
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/Sequence.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Operator.h"
@@ -2418,4 +2419,120 @@
       [](const APInt &N) { return ~N; });
 }
 
-} // anonymous namespace
+/// Evaluate the integer predicate \p Pred on the values \p LHS and \p RHS.
+bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred) {
+  assert(ICmpInst::isIntPredicate(Pred) && "Only for integer predicates!");
+  switch (Pred) {
+  case ICmpInst::Predicate::ICMP_EQ:
+    return LHS.eq(RHS);
+  case ICmpInst::Predicate::ICMP_NE:
+    return LHS.ne(RHS);
+  case ICmpInst::Predicate::ICMP_UGT:
+    return LHS.ugt(RHS);
+  case ICmpInst::Predicate::ICMP_UGE:
+    return LHS.uge(RHS);
+  case ICmpInst::Predicate::ICMP_ULT:
+    return LHS.ult(RHS);
+  case ICmpInst::Predicate::ICMP_ULE:
+    return LHS.ule(RHS);
+  case ICmpInst::Predicate::ICMP_SGT:
+    return LHS.sgt(RHS);
+  case ICmpInst::Predicate::ICMP_SGE:
+    return LHS.sge(RHS);
+  case ICmpInst::Predicate::ICMP_SLT:
+    return LHS.slt(RHS);
+  case ICmpInst::Predicate::ICMP_SLE:
+    return LHS.sle(RHS);
+  default:
+    llvm_unreachable("Unexpected non-integer predicate.");
+  }
+}
+
+/// For every pair of 4-bit ranges for which \p Func returns a predicate other
+/// than BAD_ICMP_PREDICATE, check that this predicate agrees with \p SrcPred
+/// on all pairs of values in the ranges exactly when \p Func claims it does.
+template <typename T>
+void testConstantRangeICmpPredEquivalence(ICmpInst::Predicate SrcPred,
+                                          T Func) {
+  unsigned Bits = 4;
+  EnumerateTwoConstantRanges(Bits, [&](const ConstantRange &CR1,
+                                       const ConstantRange &CR2) {
+    ICmpInst::Predicate TgtPred;
+    bool ExpectedEquivalent;
+    std::tie(TgtPred, ExpectedEquivalent) = Func(CR1, CR2, SrcPred);
+    if (TgtPred == CmpInst::Predicate::BAD_ICMP_PREDICATE)
+      return;
+    bool TrulyEquivalent = true;
+    ForeachNumInConstantRange(CR1, [&](const APInt &N1) {
+      if (!TrulyEquivalent)
+        return;
+      ForeachNumInConstantRange(CR2, [&](const APInt &N2) {
+        if (!TrulyEquivalent)
+          return;
+        TrulyEquivalent &=
+            compare(N1, N2, SrcPred) == compare(N1, N2, TgtPred);
+      });
+    });
+    ASSERT_EQ(TrulyEquivalent, ExpectedEquivalent);
+  });
+}
+
+TEST_F(ConstantRangeTest, areInsensitiveToSignednessOfICmpPredicate) {
+  for (unsigned P : seq<unsigned>(ICmpInst::FIRST_ICMP_PREDICATE,
+                                  ICmpInst::LAST_ICMP_PREDICATE + 1)) {
+    auto Pred = static_cast<ICmpInst::Predicate>(P);
+    if (ICmpInst::isEquality(Pred))
+      continue;
+    ICmpInst::Predicate FlippedSignednessPred =
+        ICmpInst::getFlippedSignednessPredicate(Pred);
+    testConstantRangeICmpPredEquivalence(
+        Pred, [FlippedSignednessPred](const ConstantRange &CR1,
+                                      const ConstantRange &CR2,
+                                      ICmpInst::Predicate) {
+          return std::make_pair(
+              FlippedSignednessPred,
+              ConstantRange::areInsensitiveToSignednessOfICmpPredicate(
+                  CR1, CR2));
+        });
+  }
+}
+
+TEST_F(ConstantRangeTest, areInsensitiveToSignednessOfSwappedICmpPredicate) {
+  for (unsigned P : seq<unsigned>(ICmpInst::FIRST_ICMP_PREDICATE,
+                                  ICmpInst::LAST_ICMP_PREDICATE + 1)) {
+    auto Pred = static_cast<ICmpInst::Predicate>(P);
+    if (ICmpInst::isEquality(Pred))
+      continue;
+    ICmpInst::Predicate SwappedFlippedSignednessPred =
+        ICmpInst::getSwappedPredicate(
+            ICmpInst::getFlippedSignednessPredicate(Pred));
+    testConstantRangeICmpPredEquivalence(
+        Pred, [SwappedFlippedSignednessPred](const ConstantRange &CR1,
+                                             const ConstantRange &CR2,
+                                             ICmpInst::Predicate) {
+          return std::make_pair(
+              SwappedFlippedSignednessPred,
+              ConstantRange::areInsensitiveToSignednessOfSwappedICmpPredicate(
+                  CR1, CR2));
+        });
+  }
+}
+
+TEST_F(ConstantRangeTest, getEquivalentPredWithFlippedSignedness) {
+  for (unsigned P : seq<unsigned>(ICmpInst::FIRST_ICMP_PREDICATE,
+                                  ICmpInst::LAST_ICMP_PREDICATE + 1)) {
+    auto Pred = static_cast<ICmpInst::Predicate>(P);
+    if (ICmpInst::isEquality(Pred))
+      continue;
+    // Only soundness is checked here: whenever a predicate is returned, it
+    // must be equivalent to Pred.
+    testConstantRangeICmpPredEquivalence(
+        Pred, [](const ConstantRange &CR1, const ConstantRange &CR2,
+                 ICmpInst::Predicate SrcPred) {
+          return std::make_pair(
+              CR1.getEquivalentPredWithFlippedSignedness(SrcPred, CR2), true);
+        });
+  }
+}
+
+} // anonymous namespace