diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -7868,19 +7868,51 @@ Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt); Operator *L = dyn_cast<Operator>(BO->LHS); - if (L && L->getOpcode() == Instruction::Shl) { + const SCEV *AddTruncateExpr = nullptr; + ConstantInt *ShlAmtCI = nullptr; + + if (L && L->getOpcode() == Instruction::Add) { + // X = Shl A, n + // Y = Add X, c + // Z = AShr Y, m + // n, c and m are constants. + + Operator *LShift = dyn_cast<Operator>(L->getOperand(0)); + ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1)); + if (LShift && LShift->getOpcode() == Instruction::Shl) { + if (AddOperandCI) { + ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1)); + const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0)); + // Since we truncate to TruncTy, the AddConstant should be of the + // same type, so create a new Constant with type same as TruncTy. + // Also, the Add constant should be shifted right by Shl amount. + APInt AddOperand = + AddOperandCI->getValue().ashr(ShlAmtCI->getZExtValue()); + const SCEV *AddConstant = getConstant(AddOperand); + + // We model the expression as sext(add(trunc(A), c << n)), since the + // sext(trunc) part is already handled below, we create an + // AddExpr(TruncExp) which will be used later. + AddTruncateExpr = + getTruncateExpr(getAddExpr(ShlOp0SCEV, AddConstant), TruncTy); + } + } + } else if (L && L->getOpcode() == Instruction::Shl) { // X = Shl A, n // Y = AShr X, m // Both n and m are constant. + ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1)); const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0)); - if (L->getOperand(1) == BO->RHS) + AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy); + } + + if (AddTruncateExpr && ShlAmtCI) { + if (ShlAmtCI == CI) // For a two-shift sext-inreg, i.e. n = m, // use sext(trunc(x)) as the SCEV expression. 
- return getSignExtendExpr( - getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy); + return getSignExtendExpr(AddTruncateExpr, OuterTy); - ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1)); if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) { uint64_t ShlAmt = ShlAmtCI->getZExtValue(); if (ShlAmt > AShrAmt) { @@ -7891,8 +7923,7 @@ APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt, ShlAmt - AShrAmt); return getSignExtendExpr( - getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy), - getConstant(Mul)), OuterTy); + getMulExpr(AddTruncateExpr, getConstant(Mul)), OuterTy); } } } diff --git a/llvm/test/Analysis/ScalarEvolution/sext-add-inreg.ll b/llvm/test/Analysis/ScalarEvolution/sext-add-inreg.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Analysis/ScalarEvolution/sext-add-inreg.ll @@ -0,0 +1,52 @@ +; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 2 +; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 | FileCheck %s + +@.str = private unnamed_addr constant [3 x i8] c"%x\00", align 1 +
+define dso_local i32 @test_loop(ptr nocapture noundef readonly %x) { +; CHECK-LABEL: 'test_loop' +; CHECK-NEXT: Classifying expressions for: @test_loop +; CHECK-NEXT: %i.03 = phi i64 [ 1, %entry ], [ %inc, %for.body ] +; CHECK-NEXT: --> {1,+,1}<%for.body> U: [1,10) S: [1,10) Exits: 9 LoopDispositions: { %for.body: Computable } +; CHECK-NEXT: %conv = shl nuw nsw i64 %i.03, 32 +; CHECK-NEXT: --> {4294967296,+,4294967296}<%for.body> U: [4294967296,38654705665) S: [4294967296,38654705665) Exits: 38654705664 LoopDispositions: { %for.body: Computable } +; CHECK-NEXT: %sext = add nsw i64 %conv, -4294967296 +; CHECK-NEXT: --> {0,+,4294967296}<%for.body> U: [0,34359738369) S: [0,34359738369) Exits: 34359738368 LoopDispositions: { %for.body: Computable } +; CHECK-NEXT: %idxprom = ashr exact i64 %sext, 32 +; CHECK-NEXT: --> {0,+,1}<%for.body> U: [0,9) S: [0,9) Exits: 8 LoopDispositions: { %for.body: Computable } +; CHECK-NEXT: %arrayidx = 
getelementptr inbounds i32, ptr %x, i64 %idxprom +; CHECK-NEXT: --> {%x,+,4}<%for.body> U: full-set S: full-set Exits: (32 + %x) LoopDispositions: { %for.body: Computable } +; CHECK-NEXT: %0 = load i32, ptr %arrayidx, align 4 +; CHECK-NEXT: --> %0 U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Variant } +; CHECK-NEXT: %call = tail call i32 (ptr, ...) @printf(ptr noundef nonnull dereferenceable(1) @.str, i32 noundef %0) +; CHECK-NEXT: --> %call U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Variant } +; CHECK-NEXT: %inc = add nuw nsw i64 %i.03, 1 +; CHECK-NEXT: --> {2,+,1}<%for.body> U: [2,11) S: [2,11) Exits: 10 LoopDispositions: { %for.body: Computable } +; CHECK-NEXT: Determining loop execution counts for: @test_loop +; CHECK-NEXT: Loop %for.body: backedge-taken count is 8 +; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is 8 +; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is 8 +; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is 8 +; CHECK-NEXT: Predicates: +; CHECK: Loop %for.body: Trip multiple is 9 +; +entry: + br label %for.body + +for.cond.cleanup: ; preds = %for.body + ret i32 0 + +for.body: ; preds = %entry, %for.body + %i.03 = phi i64 [ 1, %entry ], [ %inc, %for.body ] + %conv = shl nuw nsw i64 %i.03, 32 + %sext = add nsw i64 %conv, -4294967296 + %idxprom = ashr exact i64 %sext, 32 + %arrayidx = getelementptr inbounds i32, ptr %x, i64 %idxprom + %0 = load i32, ptr %arrayidx, align 4 + %call = tail call i32 (ptr, ...) @printf(ptr noundef nonnull dereferenceable(1) @.str, i32 noundef %0) + %inc = add nuw nsw i64 %i.03, 1 + %exitcond.not = icmp eq i64 %inc, 10 + br i1 %exitcond.not, label %for.cond.cleanup, label %for.body +} + +declare noundef i32 @printf(ptr nocapture noundef readonly, ...) 
diff --git a/llvm/test/Transforms/LoopStrengthReduce/scaling-factor-incompat-type.ll b/llvm/test/Transforms/LoopStrengthReduce/scaling-factor-incompat-type.ll --- a/llvm/test/Transforms/LoopStrengthReduce/scaling-factor-incompat-type.ll +++ b/llvm/test/Transforms/LoopStrengthReduce/scaling-factor-incompat-type.ll @@ -4,7 +4,6 @@ ; see pr42770 ; REQUIRES: asserts ; RUN: opt < %s -loop-reduce -S | FileCheck %s - target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128-ni:1" define void @foo() { @@ -12,17 +11,19 @@ ; CHECK-NEXT: bb: ; CHECK-NEXT: br label [[BB4:%.*]] ; CHECK: bb1: -; CHECK-NEXT: [[T3:%.*]] = ashr i64 [[LSR_IV_NEXT:%.*]], 32 +; CHECK-NEXT: [[T:%.*]] = shl i64 [[T14:%.*]], 32 +; CHECK-NEXT: [[T2:%.*]] = add i64 [[T]], 1 +; CHECK-NEXT: [[T3:%.*]] = ashr i64 [[T2]], 32 ; CHECK-NEXT: ret void ; CHECK: bb4: -; CHECK-NEXT: [[LSR_IV1:%.*]] = phi i16 [ [[LSR_IV_NEXT2:%.*]], [[BB13:%.*]] ], [ 6, [[BB:%.*]] ] -; CHECK-NEXT: [[LSR_IV:%.*]] = phi i64 [ [[LSR_IV_NEXT]], [[BB13]] ], [ 8589934593, [[BB]] ] -; CHECK-NEXT: [[LSR_IV_NEXT]] = add nuw nsw i64 [[LSR_IV]], 25769803776 -; CHECK-NEXT: [[LSR_IV_NEXT2]] = add nuw nsw i16 [[LSR_IV1]], 6 +; CHECK-NEXT: [[LSR_IV:%.*]] = phi i16 [ [[LSR_IV_NEXT:%.*]], [[BB13:%.*]] ], [ 6, [[BB:%.*]] ] +; CHECK-NEXT: [[T5:%.*]] = phi i64 [ 2, [[BB]] ], [ [[T14]], [[BB13]] ] +; CHECK-NEXT: [[LSR_IV_NEXT]] = add nuw nsw i16 [[LSR_IV]], 6 +; CHECK-NEXT: [[T14]] = add nuw nsw i64 [[T5]], 6 ; CHECK-NEXT: [[T10:%.*]] = icmp eq i16 1, 0 ; CHECK-NEXT: br i1 [[T10]], label [[BB11:%.*]], label [[BB13]] ; CHECK: bb11: -; CHECK-NEXT: [[T12:%.*]] = udiv i16 1, [[LSR_IV1]] +; CHECK-NEXT: [[T12:%.*]] = udiv i16 1, [[LSR_IV]] ; CHECK-NEXT: unreachable ; CHECK: bb13: ; CHECK-NEXT: br i1 true, label [[BB1:%.*]], label [[BB4]]