diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -7876,32 +7876,67 @@
       Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
 
       Operator *L = dyn_cast<Operator>(BO->LHS);
-      if (L && L->getOpcode() == Instruction::Shl) {
+      const SCEV *AddTruncateExpr = nullptr;
+      ConstantInt *ShlAmtCI = nullptr;
+      const SCEV *AddConstant = nullptr;
+
+      if (L && L->getOpcode() == Instruction::Add) {
+        // X = Shl A, n
+        // Y = Add X, c
+        // Z = AShr Y, m
+        // n, c and m are constants.
+
+        Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
+        ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
+        if (LShift && LShift->getOpcode() == Instruction::Shl && AddOperandCI) {
+          const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
+          ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
+          // Once n >= m is established below, the low m bits of X are zero, so
+          // adding c cannot carry out of them and
+          //   ashr(X + c, m) == sext(trunc(A) * 2^(n-m) + (c ashr m)).
+          // Compute (c ashr m) and truncate it to TruncTy as an APInt rather
+          // than through getZExtValue(), which asserts for negative constants
+          // wider than 64 bits.
+          APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
+          AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
+          AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
+        }
+      } else if (L && L->getOpcode() == Instruction::Shl) {
         // X = Shl A, n
         // Y = AShr X, m
         // Both n and m are constant.
 
         const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
-        if (L->getOperand(1) == BO->RHS)
-          // For a two-shift sext-inreg, i.e. n = m,
-          // use sext(trunc(x)) as the SCEV expression.
-          return getSignExtendExpr(
-              getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy);
-
-        ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
-        if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) {
-          uint64_t ShlAmt = ShlAmtCI->getZExtValue();
-          if (ShlAmt > AShrAmt) {
-            // When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
-            // expression. We already checked that ShlAmt < BitWidth, so
-            // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
-            // ShlAmt - AShrAmt < Amt.
-            APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
-                                            ShlAmt - AShrAmt);
-            return getSignExtendExpr(
-                getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy),
-                           getConstant(Mul)), OuterTy);
-          }
+        ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
+        AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
+      }
+
+      if (AddTruncateExpr && ShlAmtCI) {
+        // Both patterns are modeled by a single expression:
+        //
+        // 1) For a two-shift sext-inreg, i.e. n = m, use sext(trunc(x)) as
+        //    the SCEV expression; the multiplier below is 2^0 = 1 and
+        //    getMulExpr folds it away.
+        //
+        // 2) When n > m, use sext(mul(trunc(x), 2^(n-m))) as the SCEV
+        //    expression. ShlAmt < BitWidth is checked below, so the
+        //    multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
+        //    ShlAmt - AShrAmt < BitWidth - AShrAmt.
+        //
+        // For the Add pattern, (c ashr m) is added before sign-extending.
+        //
+        // Compare the shift amount as an APInt: it need not fit in 64 bits,
+        // and getZExtValue() asserts in that case.
+        const APInt &ShlAmt = ShlAmtCI->getValue();
+        if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
+          APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
+                                          ShlAmt.getZExtValue() - AShrAmt);
+          const SCEV *CompositeExpr =
+              getMulExpr(AddTruncateExpr, getConstant(Mul));
+          if (AddConstant)
+            CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
+
+          return getSignExtendExpr(CompositeExpr, OuterTy);
         }
       }
       break;
diff --git a/llvm/test/Analysis/ScalarEvolution/sext-add-inreg-i128.ll b/llvm/test/Analysis/ScalarEvolution/sext-add-inreg-i128.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Analysis/ScalarEvolution/sext-add-inreg-i128.ll
@@ -0,0 +1,27 @@
+; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 | FileCheck %s
+
+; ashr(add(shl(A, n), c), m) and ashr(shl(A, n), m) on integer types wider
+; than 64 bits must be analyzed without calling APInt::getZExtValue() on
+; values that do not fit in 64 bits.
+
+; c ashr m = -2^72 ashr 64 = -256 is negative and 128 bits wide.
+define i128 @neg_add_constant_i128(i128 %a) {
+; CHECK-LABEL: 'neg_add_constant_i128'
+; CHECK:       %ashr = ashr i128 %add, 64
+; CHECK-NEXT:    --> (sext i64 (-256 + (256 * (trunc i128 %a to i64))) to i128)
+  %shl = shl i128 %a, 72
+  %add = add i128 %shl, -4722366482869645213696
+  %ashr = ashr i128 %add, 64
+  ret i128 %ashr
+}
+
+; A shl amount of 2^64 does not fit in 64 bits; it must be rejected by the
+; APInt range check.
+define i128 @huge_shl_amount_i128(i128 %a) {
+; CHECK-LABEL: 'huge_shl_amount_i128'
+; CHECK:       %ashr = ashr i128 %shl, 8
+; CHECK-NEXT:    --> %ashr
+  %shl = shl i128 %a, 18446744073709551616
+  %ashr = ashr i128 %shl, 8
+  ret i128 %ashr
+}
diff --git a/llvm/test/Analysis/ScalarEvolution/sext-add-inreg-loop.ll b/llvm/test/Analysis/ScalarEvolution/sext-add-inreg-loop.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Analysis/ScalarEvolution/sext-add-inreg-loop.ll
@@ -0,0 +1,55 @@
+; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 | FileCheck %s
+
+; %idxprom = ashr(add(shl(%i.03, 32), -2^32), 32) is the affine recurrence
+; %i.03 - 1 = {0,+,1}. AddRec no-wrap flags are matched with a regex so the
+; checks do not depend on flag inference.
+
+@.str = private unnamed_addr constant [3 x i8] c"%x\00", align 1
+
+define dso_local i32 @test_loop(ptr nocapture noundef readonly %x) {
+; CHECK-LABEL: 'test_loop'
+; CHECK-NEXT:  Classifying expressions for: @test_loop
+; CHECK-NEXT:    %i.03 = phi i64 [ 1, %entry ], [ %inc, %for.body ]
+; CHECK-NEXT:    --> {1,+,1}{{.*}}<%for.body> U: [1,10) S: [1,10) Exits: 9 LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    %conv = shl nuw nsw i64 %i.03, 32
+; CHECK-NEXT:    --> {4294967296,+,4294967296}{{.*}}<%for.body> U: [4294967296,38654705665) S: [4294967296,38654705665) Exits: 38654705664 LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    %sext = add nsw i64 %conv, -4294967296
+; CHECK-NEXT:    --> {0,+,4294967296}{{.*}}<%for.body> U: [0,34359738369) S: [0,34359738369) Exits: 34359738368 LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    %idxprom = ashr exact i64 %sext, 32
+; CHECK-NEXT:    --> {0,+,1}{{.*}}<%for.body> U: [0,9) S: [0,9) Exits: 8 LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    %arrayidx = getelementptr inbounds i32, ptr %x, i64 %idxprom
+; CHECK-NEXT:    --> {%x,+,4}{{.*}}<%for.body> U: full-set S: full-set Exits: (32 + %x) LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    %0 = load i32, ptr %arrayidx, align 4
+; CHECK-NEXT:    --> %0 U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Variant }
+; CHECK-NEXT:    %call = tail call i32 (ptr, ...) @printf(ptr noundef nonnull dereferenceable(1) @.str, i32 noundef %0)
+; CHECK-NEXT:    --> %call U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Variant }
+; CHECK-NEXT:    %inc = add nuw nsw i64 %i.03, 1
+; CHECK-NEXT:    --> {2,+,1}{{.*}}<%for.body> U: [2,11) S: [2,11) Exits: 10 LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:  Determining loop execution counts for: @test_loop
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is 8
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 8
+; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is 8
+; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is 8
+; CHECK-NEXT:   Predicates:
+; CHECK:       Loop %for.body: Trip multiple is 9
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body
+  ret i32 0
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.03 = phi i64 [ 1, %entry ], [ %inc, %for.body ]
+  %conv = shl nuw nsw i64 %i.03, 32
+  %sext = add nsw i64 %conv, -4294967296
+  %idxprom = ashr exact i64 %sext, 32
+  %arrayidx = getelementptr inbounds i32, ptr %x, i64 %idxprom
+  %0 = load i32, ptr %arrayidx, align 4
+  %call = tail call i32 (ptr, ...) @printf(ptr noundef nonnull dereferenceable(1) @.str, i32 noundef %0)
+  %inc = add nuw nsw i64 %i.03, 1
+  %exitcond.not = icmp eq i64 %inc, 10
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+declare noundef i32 @printf(ptr nocapture noundef readonly, ...)
diff --git a/llvm/test/Analysis/ScalarEvolution/sext-add-inreg-unequal.ll b/llvm/test/Analysis/ScalarEvolution/sext-add-inreg-unequal.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Analysis/ScalarEvolution/sext-add-inreg-unequal.ll
@@ -0,0 +1,55 @@
+; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 | FileCheck %s
+
+; ashr(add(shl(A, n), c), m) with n > m (test00, test02) and n < m (test01).
+; No-wrap flags SCEV may infer in test00 are matched with a regex.
+
+define i64 @test00(i64 %a) {
+; CHECK-LABEL: 'test00'
+; CHECK-NEXT:  Classifying expressions for: @test00
+; CHECK-NEXT:    %shl = shl i64 %a, 10
+; CHECK-NEXT:    --> (1024 * %a) U: [0,-1023) S: [-9223372036854775808,9223372036854774785)
+; CHECK-NEXT:    %add = add i64 %shl, 256
+; CHECK-NEXT:    --> (256 + (1024 * %a)){{.*}} U: [256,-767) S: [-9223372036854775552,9223372036854775041)
+; CHECK-NEXT:    %ashr = ashr exact i64 %add, 8
+; CHECK-NEXT:    --> (1 + (sext i56 (4 * (trunc i64 %a to i56)) to i64)){{.*}} U: [1,-2) S: [-36028797018963967,36028797018963966)
+; CHECK-NEXT:  Determining loop execution counts for: @test00
+;
+  %shl = shl i64 %a, 10
+  %add = add i64 %shl, 256
+  %ashr = ashr exact i64 %add, 8
+  ret i64 %ashr
+}
+
+define i64 @test01(i64 %a) {
+; CHECK-LABEL: 'test01'
+; CHECK-NEXT:  Classifying expressions for: @test01
+; CHECK-NEXT:    %shl = shl i64 %a, 6
+; CHECK-NEXT:    --> (64 * %a) U: [0,-63) S: [-9223372036854775808,9223372036854775745)
+; CHECK-NEXT:    %add = add i64 %shl, 256
+; CHECK-NEXT:    --> (256 + (64 * %a)) U: [0,-63) S: [-9223372036854775808,9223372036854775745)
+; CHECK-NEXT:    %ashr = ashr exact i64 %add, 8
+; CHECK-NEXT:    --> %ashr U: [-36028797018963968,36028797018963968) S: [-36028797018963968,36028797018963968)
+; CHECK-NEXT:  Determining loop execution counts for: @test01
+;
+  %shl = shl i64 %a, 6
+  %add = add i64 %shl, 256
+  %ashr = ashr exact i64 %add, 8
+  ret i64 %ashr
+}
+
+define i64 @test02(i64 %a) {
+; CHECK-LABEL: 'test02'
+; CHECK-NEXT:  Classifying expressions for: @test02
+; CHECK-NEXT:    %shl = shl i64 %a, 12
+; CHECK-NEXT:    --> (4096 * %a) U: [0,-4095) S: [-9223372036854775808,9223372036854771713)
+; CHECK-NEXT:    %add = add i64 %shl, 4096
+; CHECK-NEXT:    --> (4096 + (4096 * %a)) U: [0,-4095) S: [-9223372036854775808,9223372036854771713)
+; CHECK-NEXT:    %ashr = ashr exact i64 %add, 8
+; CHECK-NEXT:    --> (sext i56 (16 + (16 * (trunc i64 %a to i56))) to i64) U: [0,-15) S: [-36028797018963968,36028797018963953)
+; CHECK-NEXT:  Determining loop execution counts for: @test02
+;
+  %shl = shl i64 %a, 12
+  %add = add i64 %shl, 4096
+  %ashr = ashr exact i64 %add, 8
+  ret i64 %ashr
+}
diff --git a/llvm/test/Analysis/ScalarEvolution/sext-add-inreg.ll b/llvm/test/Analysis/ScalarEvolution/sext-add-inreg.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Analysis/ScalarEvolution/sext-add-inreg.ll
@@ -0,0 +1,19 @@
+; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 | FileCheck %s
+
+define i64 @test(i64 %a) {
+; CHECK-LABEL: 'test'
+; CHECK-NEXT:  Classifying expressions for: @test
+; CHECK-NEXT:    %shl = shl i64 %a, 8
+; CHECK-NEXT:    --> (256 * %a) U: [0,-255) S: [-9223372036854775808,9223372036854775553)
+; CHECK-NEXT:    %add = add i64 %shl, 256
+; CHECK-NEXT:    --> (256 + (256 * %a)) U: [0,-255) S: [-9223372036854775808,9223372036854775553)
+; CHECK-NEXT:    %ashr = ashr exact i64 %add, 8
+; CHECK-NEXT:    --> (sext i56 (1 + (trunc i64 %a to i56)) to i64) U: [-36028797018963968,36028797018963968) S: [-36028797018963968,36028797018963968)
+; CHECK-NEXT:  Determining loop execution counts for: @test
+;
+  %shl = shl i64 %a, 8
+  %add = add i64 %shl, 256
+  %ashr = ashr exact i64 %add, 8
+  ret i64 %ashr
+}
diff --git a/llvm/test/Transforms/LoopStrengthReduce/scaling-factor-incompat-type.ll b/llvm/test/Transforms/LoopStrengthReduce/scaling-factor-incompat-type.ll
--- a/llvm/test/Transforms/LoopStrengthReduce/scaling-factor-incompat-type.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/scaling-factor-incompat-type.ll
@@ -12,23 +12,23 @@
 ; CHECK-NEXT:  bb:
 ; CHECK-NEXT:    br label [[BB4:%.*]]
 ; CHECK:       bb1:
-; CHECK-NEXT:    [[T3:%.*]] = ashr i64 [[LSR_IV_NEXT:%.*]], 32
+; CHECK-NEXT:    [[T:%.*]] = shl i64 [[T14:%.*]], 32
+; CHECK-NEXT:    [[T2:%.*]] = add i64 [[T]], 1
+; CHECK-NEXT:    [[T3:%.*]] = ashr i64 [[T2]], 32
 ; CHECK-NEXT:    ret void
 ; CHECK:       bb4:
-; CHECK-NEXT:    [[LSR_IV1:%.*]] = phi i16 [ [[LSR_IV_NEXT2:%.*]], [[BB13:%.*]] ], [ 6, [[BB:%.*]] ]
-; CHECK-NEXT:    [[LSR_IV:%.*]] = phi i64 [ [[LSR_IV_NEXT]], [[BB13]] ], [ 8589934593, [[BB]] ]
-; CHECK-NEXT:    [[T5:%.*]] = phi i64 [ 2, [[BB]] ], [ [[T14:%.*]], [[BB13]] ]
+; CHECK-NEXT:    [[LSR_IV:%.*]] = phi i16 [ [[LSR_IV_NEXT:%.*]], [[BB13:%.*]] ], [ 6, [[BB:%.*]] ]
+; CHECK-NEXT:    [[T5:%.*]] = phi i64 [ 2, [[BB]] ], [ [[T14]], [[BB13]] ]
 ; CHECK-NEXT:    [[T6:%.*]] = add i64 [[T5]], 4
 ; CHECK-NEXT:    [[T7:%.*]] = trunc i64 [[T6]] to i16
 ; CHECK-NEXT:    [[T8:%.*]] = urem i16 [[T7]], 3
 ; CHECK-NEXT:    [[T9:%.*]] = mul i16 [[T8]], 2
-; CHECK-NEXT:    [[LSR_IV_NEXT]] = add nuw nsw i64 [[LSR_IV]], 25769803776
-; CHECK-NEXT:    [[LSR_IV_NEXT2]] = add nuw nsw i16 [[LSR_IV1]], 6
+; CHECK-NEXT:    [[LSR_IV_NEXT]] = add nuw nsw i16 [[LSR_IV]], 6
 ; CHECK-NEXT:    [[T14]] = add nuw nsw i64 [[T5]], 6
 ; CHECK-NEXT:    [[T10:%.*]] = icmp eq i16 [[T9]], 1
 ; CHECK-NEXT:    br i1 [[T10]], label [[BB11:%.*]], label [[BB13]]
 ; CHECK:       bb11:
-; CHECK-NEXT:    [[T12:%.*]] = udiv i16 1, [[LSR_IV1]]
+; CHECK-NEXT:    [[T12:%.*]] = udiv i16 1, [[LSR_IV]]
 ; CHECK-NEXT:    unreachable
 ; CHECK:       bb13:
 ; CHECK-NEXT:    br i1 true, label [[BB1:%.*]], label [[BB4]]