Index: include/llvm/Analysis/ValueLattice.h
===================================================================
--- include/llvm/Analysis/ValueLattice.h
+++ include/llvm/Analysis/ValueLattice.h
@@ -286,24 +286,32 @@
     return cast<ConstantInt>(getConstant());
   }
 
-  bool satisfiesPredicate(CmpInst::Predicate Pred,
-                          const ValueLatticeElement &Other) const {
-    // TODO: share with LVI getPredicateResult.
-
+  /// Compares this symbolic value with Other using Pred and returns true,
+  /// false or undef, the (possibly unfolded) ConstantExpr::getCompare result
+  /// if both are constants, or nullptr if the comparison cannot be evaluated.
+  Constant *getCompare(CmpInst::Predicate Pred, Type *Ty,
+                       const ValueLatticeElement &Other) const {
     if (isUndefined() || Other.isUndefined())
-      return true;
+      return UndefValue::get(Ty);
 
-    if (isConstant() && Other.isConstant() && Pred == CmpInst::FCMP_OEQ)
-      return getConstant() == Other.getConstant();
+    if (isConstant() && Other.isConstant())
+      return ConstantExpr::getCompare(Pred, getConstant(), Other.getConstant());
 
     // Integer constants are represented as ConstantRanges with single
     // elements.
     if (!isConstantRange() || !Other.isConstantRange())
-      return false;
+      return nullptr;
 
     const auto &CR = getConstantRange();
     const auto &OtherCR = Other.getConstantRange();
-    return ConstantRange::makeSatisfyingICmpRegion(Pred, OtherCR).contains(CR);
+    if (ConstantRange::makeSatisfyingICmpRegion(Pred, OtherCR).contains(CR))
+      return ConstantInt::getTrue(Ty);
+    if (ConstantRange::makeSatisfyingICmpRegion(
+            CmpInst::getInversePredicate(Pred), OtherCR)
+            .contains(CR))
+      return ConstantInt::getFalse(Ty);
+
+    return nullptr;
   }
 };
 
Index: lib/Transforms/Scalar/SCCP.cpp
===================================================================
--- lib/Transforms/Scalar/SCCP.cpp
+++ lib/Transforms/Scalar/SCCP.cpp
@@ -1622,12 +1622,7 @@
     ValueLatticeElement A = getIcmpLatticeValue(Icmp->getOperand(0));
     ValueLatticeElement B = getIcmpLatticeValue(Icmp->getOperand(1));
 
-    Constant *C = nullptr;
-    if (A.satisfiesPredicate(Icmp->getPredicate(), B))
-      C = ConstantInt::getTrue(Icmp->getType());
-    else if (A.satisfiesPredicate(Icmp->getInversePredicate(), B))
-      C = ConstantInt::getFalse(Icmp->getType());
-
+    Constant *C = A.getCompare(Icmp->getPredicate(), Icmp->getType(), B);
     if (C) {
       Icmp->replaceAllUsesWith(C);
       DEBUG(dbgs() << "Replacing " << *Icmp << " with " << *C
Index: test/Transforms/SCCP/ip-constant-ranges.ll
===================================================================
--- test/Transforms/SCCP/ip-constant-ranges.ll
+++ test/Transforms/SCCP/ip-constant-ranges.ll
@@ -141,3 +141,29 @@
   %r = fmul double %v, %v
   ret double %r
 }
+
+; Constant range for %x is [47, 302)
+; CHECK-LABEL: @f5
+; CHECK-NEXT: entry:
+; CHECK-NEXT: %res1 = select i1 undef, i32 1, i32 2
+; CHECK-NEXT: %res2 = select i1 undef, i32 3, i32 4
+; CHECK-NEXT: %res = add i32 %res1, %res2
+; CHECK-NEXT: ret i32 %res
+define internal i32 @f5(i32 %x) {
+entry:
+  %cmp = icmp sgt i32 %x, undef
+  %cmp2 = icmp ne i32 undef, %x
+  %res1 = select i1 %cmp, i32 1, i32 2
+  %res2 = select i1 %cmp2, i32 3, i32 4
+
+  %res = add i32 %res1, %res2
+  ret i32 %res
+}
+
+define i32 @caller4() {
+entry:
+  %call1 = tail call i32 @f5(i32 47)
+  %call2 = tail call i32 @f5(i32 301)
+  %res = add nsw i32 %call1, %call2
+  ret i32 %res
+}
Index: unittests/Analysis/ValueLatticeTest.cpp
===================================================================
--- unittests/Analysis/ValueLatticeTest.cpp
+++ unittests/Analysis/ValueLatticeTest.cpp
@@ -76,72 +76,98 @@
   EXPECT_TRUE(LV1.isOverdefined());
 }
 
-TEST_F(ValueLatticeTest, satisfiesPredicateIntegers) {
-  auto I32Ty = IntegerType::get(Context, 32);
+TEST_F(ValueLatticeTest, getCompareIntegers) {
+  auto *I32Ty = IntegerType::get(Context, 32);
+  auto *I1Ty = IntegerType::get(Context, 1);
   auto *C1 = ConstantInt::get(I32Ty, 1);
   auto LV1 = ValueLatticeElement::get(C1);
 
-  // Check satisfiesPredicate for equal integer constants.
-  EXPECT_TRUE(LV1.satisfiesPredicate(CmpInst::ICMP_EQ, LV1));
-  EXPECT_TRUE(LV1.satisfiesPredicate(CmpInst::ICMP_SGE, LV1));
-  EXPECT_TRUE(LV1.satisfiesPredicate(CmpInst::ICMP_SLE, LV1));
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::ICMP_NE, LV1));
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::ICMP_SLT, LV1));
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::ICMP_SGT, LV1));
+  // Check getCompare for equal integer constants.
+  EXPECT_TRUE(LV1.getCompare(CmpInst::ICMP_EQ, I1Ty, LV1)->isOneValue());
+  EXPECT_TRUE(LV1.getCompare(CmpInst::ICMP_SGE, I1Ty, LV1)->isOneValue());
+  EXPECT_TRUE(LV1.getCompare(CmpInst::ICMP_SLE, I1Ty, LV1)->isOneValue());
+  EXPECT_TRUE(LV1.getCompare(CmpInst::ICMP_NE, I1Ty, LV1)->isZeroValue());
+  EXPECT_TRUE(LV1.getCompare(CmpInst::ICMP_SLT, I1Ty, LV1)->isZeroValue());
+  EXPECT_TRUE(LV1.getCompare(CmpInst::ICMP_SGT, I1Ty, LV1)->isZeroValue());
 
   auto LV2 =
       ValueLatticeElement::getRange({APInt(32, 10, true), APInt(32, 20, true)});
-  // Check satisfiesPredicate with distinct integer ranges.
-  EXPECT_TRUE(LV1.satisfiesPredicate(CmpInst::ICMP_SLT, LV2));
-  EXPECT_TRUE(LV1.satisfiesPredicate(CmpInst::ICMP_SLE, LV2));
-  EXPECT_TRUE(LV1.satisfiesPredicate(CmpInst::ICMP_NE, LV2));
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::ICMP_EQ, LV2));
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::ICMP_SGE, LV2));
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::ICMP_SGT, LV2));
+  // Check getCompare with distinct integer ranges.
+  EXPECT_TRUE(LV1.getCompare(CmpInst::ICMP_SLT, I1Ty, LV2)->isOneValue());
+  EXPECT_TRUE(LV1.getCompare(CmpInst::ICMP_SLE, I1Ty, LV2)->isOneValue());
+  EXPECT_TRUE(LV1.getCompare(CmpInst::ICMP_NE, I1Ty, LV2)->isOneValue());
+  EXPECT_TRUE(LV1.getCompare(CmpInst::ICMP_EQ, I1Ty, LV2)->isZeroValue());
+  EXPECT_TRUE(LV1.getCompare(CmpInst::ICMP_SGE, I1Ty, LV2)->isZeroValue());
+  EXPECT_TRUE(LV1.getCompare(CmpInst::ICMP_SGT, I1Ty, LV2)->isZeroValue());
 
   auto LV3 =
       ValueLatticeElement::getRange({APInt(32, 15, true), APInt(32, 19, true)});
-  // Check satisfiesPredicate with a subset integer ranges.
-  EXPECT_FALSE(LV2.satisfiesPredicate(CmpInst::ICMP_SLT, LV3));
-  EXPECT_FALSE(LV2.satisfiesPredicate(CmpInst::ICMP_SLE, LV3));
-  EXPECT_FALSE(LV2.satisfiesPredicate(CmpInst::ICMP_NE, LV3));
-  EXPECT_FALSE(LV2.satisfiesPredicate(CmpInst::ICMP_EQ, LV3));
-  EXPECT_FALSE(LV2.satisfiesPredicate(CmpInst::ICMP_SGE, LV3));
-  EXPECT_FALSE(LV2.satisfiesPredicate(CmpInst::ICMP_SGT, LV3));
+  // Check getCompare with a subset integer ranges.
+  EXPECT_EQ(LV2.getCompare(CmpInst::ICMP_SLT, I1Ty, LV3), nullptr);
+  EXPECT_EQ(LV2.getCompare(CmpInst::ICMP_SLE, I1Ty, LV3), nullptr);
+  EXPECT_EQ(LV2.getCompare(CmpInst::ICMP_NE, I1Ty, LV3), nullptr);
+  EXPECT_EQ(LV2.getCompare(CmpInst::ICMP_EQ, I1Ty, LV3), nullptr);
+  EXPECT_EQ(LV2.getCompare(CmpInst::ICMP_SGE, I1Ty, LV3), nullptr);
+  EXPECT_EQ(LV2.getCompare(CmpInst::ICMP_SGT, I1Ty, LV3), nullptr);
 
   auto LV4 =
       ValueLatticeElement::getRange({APInt(32, 15, true), APInt(32, 25, true)});
-  // Check satisfiesPredicate with overlapping integer ranges.
-  EXPECT_FALSE(LV3.satisfiesPredicate(CmpInst::ICMP_SLT, LV4));
-  EXPECT_FALSE(LV3.satisfiesPredicate(CmpInst::ICMP_SLE, LV4));
-  EXPECT_FALSE(LV3.satisfiesPredicate(CmpInst::ICMP_NE, LV4));
-  EXPECT_FALSE(LV3.satisfiesPredicate(CmpInst::ICMP_EQ, LV4));
-  EXPECT_FALSE(LV3.satisfiesPredicate(CmpInst::ICMP_SGE, LV4));
-  EXPECT_FALSE(LV3.satisfiesPredicate(CmpInst::ICMP_SGT, LV4));
+  // Check getCompare with overlapping integer ranges.
+  EXPECT_EQ(LV3.getCompare(CmpInst::ICMP_SLT, I1Ty, LV4), nullptr);
+  EXPECT_EQ(LV3.getCompare(CmpInst::ICMP_SLE, I1Ty, LV4), nullptr);
+  EXPECT_EQ(LV3.getCompare(CmpInst::ICMP_NE, I1Ty, LV4), nullptr);
+  EXPECT_EQ(LV3.getCompare(CmpInst::ICMP_EQ, I1Ty, LV4), nullptr);
+  EXPECT_EQ(LV3.getCompare(CmpInst::ICMP_SGE, I1Ty, LV4), nullptr);
+  EXPECT_EQ(LV3.getCompare(CmpInst::ICMP_SGT, I1Ty, LV4), nullptr);
 }
 
-TEST_F(ValueLatticeTest, satisfiesPredicateFloat) {
-  auto FloatTy = IntegerType::getFloatTy(Context);
+TEST_F(ValueLatticeTest, getCompareFloat) {
+  auto *FloatTy = Type::getFloatTy(Context);
+  auto *I1Ty = IntegerType::get(Context, 1);
   auto *C1 = ConstantFP::get(FloatTy, 1.0);
   auto LV1 = ValueLatticeElement::get(C1);
   auto LV2 = ValueLatticeElement::get(C1);
 
-  // Check satisfiesPredicate for equal floating point constants.
-  EXPECT_TRUE(LV1.satisfiesPredicate(CmpInst::FCMP_OEQ, LV2));
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::FCMP_OGE, LV2));
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::FCMP_OLE, LV2));
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::FCMP_ONE, LV2));
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::FCMP_OLT, LV2));
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::FCMP_OGT, LV2));
+  // Check getCompare for equal floating point constants.
+  EXPECT_TRUE(LV1.getCompare(CmpInst::FCMP_OEQ, I1Ty, LV2)->isOneValue());
+  EXPECT_TRUE(LV1.getCompare(CmpInst::FCMP_OGE, I1Ty, LV2)->isOneValue());
+  EXPECT_TRUE(LV1.getCompare(CmpInst::FCMP_OLE, I1Ty, LV2)->isOneValue());
+  EXPECT_TRUE(LV1.getCompare(CmpInst::FCMP_ONE, I1Ty, LV2)->isZeroValue());
+  EXPECT_TRUE(LV1.getCompare(CmpInst::FCMP_OLT, I1Ty, LV2)->isZeroValue());
+  EXPECT_TRUE(LV1.getCompare(CmpInst::FCMP_OGT, I1Ty, LV2)->isZeroValue());
 
   LV1.mergeIn(ValueLatticeElement::get(ConstantFP::get(FloatTy, 2.2)),
               M.getDataLayout());
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::FCMP_OEQ, LV2));
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::FCMP_OGE, LV2));
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::FCMP_OLE, LV2));
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::FCMP_ONE, LV2));
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::FCMP_OLT, LV2));
-  EXPECT_FALSE(LV1.satisfiesPredicate(CmpInst::FCMP_OGT, LV2));
+  EXPECT_EQ(LV1.getCompare(CmpInst::FCMP_OEQ, I1Ty, LV2), nullptr);
+  EXPECT_EQ(LV1.getCompare(CmpInst::FCMP_OGE, I1Ty, LV2), nullptr);
+  EXPECT_EQ(LV1.getCompare(CmpInst::FCMP_OLE, I1Ty, LV2), nullptr);
+  EXPECT_EQ(LV1.getCompare(CmpInst::FCMP_ONE, I1Ty, LV2), nullptr);
+  EXPECT_EQ(LV1.getCompare(CmpInst::FCMP_OLT, I1Ty, LV2), nullptr);
+  EXPECT_EQ(LV1.getCompare(CmpInst::FCMP_OGT, I1Ty, LV2), nullptr);
+}
+
+TEST_F(ValueLatticeTest, getCompareUndef) {
+  auto *I32Ty = IntegerType::get(Context, 32);
+  auto *I1Ty = IntegerType::get(Context, 1);
+
+  auto LV1 = ValueLatticeElement::get(UndefValue::get(I32Ty));
+  auto LV2 =
+      ValueLatticeElement::getRange({APInt(32, 10, true), APInt(32, 20, true)});
+  EXPECT_TRUE(isa<UndefValue>(LV1.getCompare(CmpInst::ICMP_SLT, I1Ty, LV2)));
+  EXPECT_TRUE(isa<UndefValue>(LV1.getCompare(CmpInst::ICMP_SLE, I1Ty, LV2)));
+  EXPECT_TRUE(isa<UndefValue>(LV1.getCompare(CmpInst::ICMP_NE, I1Ty, LV2)));
+  EXPECT_TRUE(isa<UndefValue>(LV1.getCompare(CmpInst::ICMP_EQ, I1Ty, LV2)));
+  EXPECT_TRUE(isa<UndefValue>(LV1.getCompare(CmpInst::ICMP_SGE, I1Ty, LV2)));
+  EXPECT_TRUE(isa<UndefValue>(LV1.getCompare(CmpInst::ICMP_SGT, I1Ty, LV2)));
+
+  auto *FloatTy = Type::getFloatTy(Context);
+  auto LV3 = ValueLatticeElement::get(ConstantFP::get(FloatTy, 1.0));
+  EXPECT_TRUE(isa<UndefValue>(LV1.getCompare(CmpInst::FCMP_OEQ, I1Ty, LV3)));
+  EXPECT_TRUE(isa<UndefValue>(LV1.getCompare(CmpInst::FCMP_OGE, I1Ty, LV3)));
+  EXPECT_TRUE(isa<UndefValue>(LV1.getCompare(CmpInst::FCMP_OLE, I1Ty, LV3)));
+  EXPECT_TRUE(isa<UndefValue>(LV1.getCompare(CmpInst::FCMP_ONE, I1Ty, LV3)));
+  EXPECT_TRUE(isa<UndefValue>(LV1.getCompare(CmpInst::FCMP_OLT, I1Ty, LV3)));
+  EXPECT_TRUE(isa<UndefValue>(LV1.getCompare(CmpInst::FCMP_OGT, I1Ty, LV3)));
 }
 
 } // end anonymous namespace