diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -5532,48 +5532,58 @@ switch (ICI->getPredicate()) { case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - std::swap(LHS, RHS); - LLVM_FALLTHROUGH; - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - // a >s b ? a+x : b+x -> smax(a, b)+x - // a >s b ? b+x : a+x -> smin(a, b)+x - if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) { - const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), I->getType()); - const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), I->getType()); - const SCEV *LA = getSCEV(TrueVal); - const SCEV *RA = getSCEV(FalseVal); - const SCEV *LDiff = getMinusSCEV(LA, LS); - const SCEV *RDiff = getMinusSCEV(RA, RS); - if (LDiff == RDiff) - return getAddExpr(getSMaxExpr(LS, RS), LDiff); - LDiff = getMinusSCEV(LA, RS); - RDiff = getMinusSCEV(RA, LS); - if (LDiff == RDiff) - return getAddExpr(getSMinExpr(LS, RS), LDiff); - } - break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: std::swap(LHS, RHS); LLVM_FALLTHROUGH; + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: - // a >u b ? a+x : b+x -> umax(a, b)+x - // a >u b ? b+x : a+x -> umin(a, b)+x + // a > b ? a+x : b+x -> max(a, b)+x + // a > b ? b+x : a+x -> min(a, b)+x if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) { - const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); - const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), I->getType()); const SCEV *LA = getSCEV(TrueVal); const SCEV *RA = getSCEV(FalseVal); + const SCEV *LS = getSCEV(LHS); + const SCEV *RS = getSCEV(RHS); + if (LA->getType()->isPointerTy()) { + // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA. + // Need to make sure we can't produce weird expressions involving + // negated pointers. + const SCEV *LS = getSCEV(LHS); + const SCEV *RS = getSCEV(RHS); + if (LA == LS && RA == RS) + return ICI->isSigned() ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS); + if (LA == RS && RA == LS) + return ICI->isSigned() ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS); + } + auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * { + if (Op->getType()->isPointerTy()) { + Op = getLosslessPtrToIntExpr(Op); + if (isa(Op)) + return Op; + } + if (ICI->isSigned()) + Op = getNoopOrSignExtend(Op, I->getType()); + else + Op = getNoopOrZeroExtend(Op, I->getType()); + return Op; + }; + LS = CoerceOperand(LS); + RS = CoerceOperand(RS); + if (isa(LS) || isa(RS)) + break; const SCEV *LDiff = getMinusSCEV(LA, LS); const SCEV *RDiff = getMinusSCEV(RA, RS); if (LDiff == RDiff) - return getAddExpr(getUMaxExpr(LS, RS), LDiff); + return getAddExpr( + ICI->isSigned() ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS), LDiff); LDiff = getMinusSCEV(LA, RS); RDiff = getMinusSCEV(RA, LS); if (LDiff == RDiff) - return getAddExpr(getUMinExpr(LS, RS), LDiff); + return getAddExpr( + ICI->isSigned() ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS), LDiff); } break; case ICmpInst::ICMP_NE: diff --git a/llvm/test/Analysis/ScalarEvolution/pr46786.ll b/llvm/test/Analysis/ScalarEvolution/pr46786.ll --- a/llvm/test/Analysis/ScalarEvolution/pr46786.ll +++ b/llvm/test/Analysis/ScalarEvolution/pr46786.ll @@ -16,11 +16,11 @@ ; CHECK-NEXT: %i5 = getelementptr inbounds i8, i8* %i, i32 %i4 ; CHECK-NEXT: --> ((-1 * %arg1) + %arg2 + %arg) U: full-set S: full-set ; CHECK-NEXT: %i7 = select i1 %i6, i32 %arg2, i32 %arg1 -; CHECK-NEXT: --> ((-1 * %arg) + (((-1 * %arg1) + %arg2 + %arg) umin %arg) + %arg1) U: full-set S: full-set +; CHECK-NEXT: --> ((-1 * (ptrtoint i8* %arg to i32)) + (((-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg2) umin (ptrtoint i8* %arg to i32)) + %arg1) U: full-set S: full-set ; CHECK-NEXT: %i8 = sub i32 %arg3, %i7 -; CHECK-NEXT: --> ((-1 * (((-1 * %arg1) + %arg2 + %arg) umin %arg)) + (-1 * %arg1) + %arg3 + %arg) U: full-set S: full-set +; CHECK-NEXT: --> ((-1 * (((-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg2) umin (ptrtoint i8* %arg to i32))) + (-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg3) U: full-set S: full-set ; CHECK-NEXT: %i9 = getelementptr inbounds i8, i8* %arg, i32 %i8 -; CHECK-NEXT: --> ((2 * %arg) + (-1 * (((-1 * %arg1) + %arg2 + %arg) umin %arg)) + (-1 * %arg1) + %arg3) U: full-set S: full-set +; CHECK-NEXT: --> ((-1 * (((-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg2) umin (ptrtoint i8* %arg to i32))) + (-1 * %arg1) + (ptrtoint i8* %arg to i32) + %arg3 + %arg) U: full-set S: full-set ; CHECK-NEXT: Determining loop execution counts for: @FSE_decompress_usingDTable ; bb: @@ -42,11 +42,11 @@ ; CHECK-NEXT: %p2 = getelementptr i8, i8* %p, i32 1 ; CHECK-NEXT: --> (1 + %p) U: full-set S: full-set ; CHECK-NEXT: %index = select i1 %cmp, i32 2, i32 1 -; CHECK-NEXT: --> ((-1 * %p) + ((1 + %p) umax (2 + %p))) U: full-set S: full-set +; CHECK-NEXT: --> ((-1 * (ptrtoint i8* %p to i32)) + ((1 + (ptrtoint i8* %p to i32)) umax (2 + (ptrtoint i8* %p to i32)))) U: full-set S: full-set ; CHECK-NEXT: %neg_index = sub i32 0, %index -; CHECK-NEXT: --> ((-1 * ((1 + %p) umax (2 + %p))) + %p) U: full-set S: full-set +; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) umax (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32)) U: full-set S: full-set ; CHECK-NEXT: %gep = getelementptr i8, i8* %p, i32 %neg_index -; CHECK-NEXT: --> ((2 * %p) + (-1 * ((1 + %p) umax (2 + %p)))) U: full-set S: full-set +; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) umax (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32) + %p) U: full-set S: full-set ; CHECK-NEXT: Determining loop execution counts for: @test_01 ; %p1 = getelementptr i8, i8* %p, i32 2 @@ -66,11 +66,11 @@ ; CHECK-NEXT: %p2 = getelementptr i8, i8* %p, i32 1 ; CHECK-NEXT: --> (1 + %p) U: full-set S: full-set ; CHECK-NEXT: %index = select i1 %cmp, i32 2, i32 1 -; CHECK-NEXT: --> ((-1 * %p) + ((1 + %p) smax (2 + %p))) U: full-set S: full-set +; CHECK-NEXT: --> ((-1 * (ptrtoint i8* %p to i32)) + ((1 + (ptrtoint i8* %p to i32)) smax (2 + (ptrtoint i8* %p to i32)))) U: full-set S: full-set ; CHECK-NEXT: %neg_index = sub i32 0, %index -; CHECK-NEXT: --> ((-1 * ((1 + %p) smax (2 + %p))) + %p) U: full-set S: full-set +; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) smax (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32)) U: full-set S: full-set ; CHECK-NEXT: %gep = getelementptr i8, i8* %p, i32 %neg_index -; CHECK-NEXT: --> ((2 * %p) + (-1 * ((1 + %p) smax (2 + %p)))) U: full-set S: full-set +; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) smax (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32) + %p) U: full-set S: full-set ; CHECK-NEXT: Determining loop execution counts for: @test_02 ; %p1 = getelementptr i8, i8* %p, i32 2 @@ -90,11 +90,11 @@ ; CHECK-NEXT: %p2 = getelementptr i8, i8* %p, i32 1 ; CHECK-NEXT: --> (1 + %p) U: full-set S: full-set ; CHECK-NEXT: %index = select i1 %cmp, i32 2, i32 1 -; CHECK-NEXT: --> ((-1 * %p) + ((1 + %p) umin (2 + %p))) U: full-set S: full-set +; CHECK-NEXT: --> ((-1 * (ptrtoint i8* %p to i32)) + ((1 + (ptrtoint i8* %p to i32)) umin (2 + (ptrtoint i8* %p to i32)))) U: full-set S: full-set ; CHECK-NEXT: %neg_index = sub i32 0, %index -; CHECK-NEXT: --> ((-1 * ((1 + %p) umin (2 + %p))) + %p) U: full-set S: full-set +; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) umin (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32)) U: full-set S: full-set ; CHECK-NEXT: %gep = getelementptr i8, i8* %p, i32 %neg_index -; CHECK-NEXT: --> ((2 * %p) + (-1 * ((1 + %p) umin (2 + %p)))) U: full-set S: full-set +; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) umin (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32) + %p) U: full-set S: full-set ; CHECK-NEXT: Determining loop execution counts for: @test_03 ; %p1 = getelementptr i8, i8* %p, i32 2 @@ -114,11 +114,11 @@ ; CHECK-NEXT: %p2 = getelementptr i8, i8* %p, i32 1 ; CHECK-NEXT: --> (1 + %p) U: full-set S: full-set ; CHECK-NEXT: %index = select i1 %cmp, i32 2, i32 1 -; CHECK-NEXT: --> ((-1 * %p) + ((1 + %p) smin (2 + %p))) U: full-set S: full-set +; CHECK-NEXT: --> ((-1 * (ptrtoint i8* %p to i32)) + ((1 + (ptrtoint i8* %p to i32)) smin (2 + (ptrtoint i8* %p to i32)))) U: full-set S: full-set ; CHECK-NEXT: %neg_index = sub i32 0, %index -; CHECK-NEXT: --> ((-1 * ((1 + %p) smin (2 + %p))) + %p) U: full-set S: full-set +; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) smin (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32)) U: full-set S: full-set ; CHECK-NEXT: %gep = getelementptr i8, i8* %p, i32 %neg_index -; CHECK-NEXT: --> ((2 * %p) + (-1 * ((1 + %p) smin (2 + %p)))) U: full-set S: full-set +; CHECK-NEXT: --> ((-1 * ((1 + (ptrtoint i8* %p to i32)) smin (2 + (ptrtoint i8* %p to i32)))) + (ptrtoint i8* %p to i32) + %p) U: full-set S: full-set ; CHECK-NEXT: Determining loop execution counts for: @test_04 ; %p1 = getelementptr i8, i8* %p, i32 2 diff --git a/llvm/test/Transforms/IndVarSimplify/pr45835.ll b/llvm/test/Transforms/IndVarSimplify/pr45835.ll --- a/llvm/test/Transforms/IndVarSimplify/pr45835.ll +++ b/llvm/test/Transforms/IndVarSimplify/pr45835.ll @@ -10,7 +10,7 @@ define internal fastcc void @d(i8* %c) unnamed_addr #0 { entry: - %cmp = icmp ule i8* %c, getelementptr inbounds (i8, i8* @a, i64 65535) + %cmp = icmp ule i8* %c, @a %add.ptr = getelementptr inbounds i8, i8* %c, i64 -65535 br label %while.cond @@ -18,7 +18,7 @@ br i1 icmp ne (i8 0, i8 0), label %cont, label %while.end cont: - %a.mux = select i1 %cmp, i8* @a, i8* %add.ptr + %a.mux = select i1 %cmp, i8* @a, i8* %c switch i64 0, label %while.cond [ i64 -1, label %handler.pointer_overflow.i i64 0, label %handler.pointer_overflow.i @@ -26,7 +26,7 @@ handler.pointer_overflow.i: %a.mux.lcssa4 = phi i8* [ %a.mux, %cont ], [ %a.mux, %cont ] -; ALWAYS: [ %scevgep, %cont ], [ %scevgep, %cont ] +; ALWAYS: [ %umax, %cont ], [ %umax, %cont ] ; NEVER: [ %a.mux, %cont ], [ %a.mux, %cont ] ; In cheap mode, use either one as long as it's consistent. ; CHEAP: [ %[[VAL:.*]], %cont ], [ %[[VAL]], %cont ] diff --git a/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp b/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp --- a/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp +++ b/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp @@ -118,20 +118,7 @@ ScalarEvolution SE = buildSE(*F); auto *S = SE.getSCEV(CastB); - SCEVExpander Exp(SE, M.getDataLayout(), "expander"); - Value *V = - Exp.expandCodeFor(cast(S)->getOperand(1), nullptr, Br); - - // Expect the expansion code contains: - // %0 = bitcast i32* %bitcast2 to i8* - // %uglygep = getelementptr i8, i8* %0, i64 -1 - // %1 = bitcast i8* %uglygep to i32* - EXPECT_TRUE(isa(V)); - Instruction *Gep = cast(V)->getPrevNode(); - EXPECT_TRUE(isa(Gep)); - EXPECT_TRUE(isa(Gep->getOperand(1))); - EXPECT_EQ(cast(Gep->getOperand(1))->getSExtValue(), -1); - EXPECT_TRUE(isa(Gep->getPrevNode())); + EXPECT_TRUE(isa(S)); } // Make sure that SCEV doesn't introduce illegal ptrtoint/inttoptr instructions