Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp
+++ lib/Analysis/ScalarEvolution.cpp
@@ -5254,28 +5254,62 @@
       break;
 
     case Instruction::AShr:
-      // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
-      if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS))
-        if (Operator *L = dyn_cast<Operator>(BO->LHS))
-          if (L->getOpcode() == Instruction::Shl &&
-              L->getOperand(1) == BO->RHS) {
-            uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType());
-
-            // If the shift count is not less than the bitwidth, the result of
-            // the shift is undefined. Don't try to analyze it, because the
-            // resolution chosen here may differ from the resolution chosen in
-            // other parts of the compiler.
-            if (CI->getValue().uge(BitWidth))
-              break;
+      ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
+      Operator *L = dyn_cast<Operator>(BO->LHS);
+      if (CI && L) {
+        if (L->getOpcode() == Instruction::Shl) {
+          uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType());
+          // If the shift count is not less than the bitwidth, the result of
+          // the shift is undefined. Don't try to analyze it, because the
+          // resolution chosen here may differ from the resolution chosen in
+          // other parts of the compiler.
+          if (CI->getValue().uge(BitWidth))
+            break;
 
-            uint64_t Amt = BitWidth - CI->getZExtValue();
-            if (Amt == BitWidth)
-              return getSCEV(L->getOperand(0)); // shift by zero --> noop
+          // A shift right by zero leaves the shl result unchanged. That is
+          // only the shl's first operand when the shl amount is zero too, so
+          // use the SCEV of the shl itself.
+          if (CI->isNullValue())
+            return getSCEV(BO->LHS); // shift by zero --> noop
+
+          const SCEV *LShOp0SCEV = getSCEV(L->getOperand(0));
+          uint64_t AShrAmt = CI->getZExtValue();
+          uint64_t TruncToWidth = BitWidth - AShrAmt;
+          Type *TruncTy = IntegerType::get(getContext(), TruncToWidth);
+          Type *SExtTy = BO->LHS->getType();
+          Value *LShOp1 = L->getOperand(1);
+          if (LShOp1 == BO->RHS)
+            // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV
+            // expression.
             return getSignExtendExpr(
-                getTruncateExpr(getSCEV(L->getOperand(0)),
-                                IntegerType::get(getContext(), Amt)),
-                BO->LHS->getType());
-          }
+                getTruncateExpr(LShOp0SCEV, TruncTy), SExtTy);
+
+          // Handle below case:
+          //   %y = shl %x, n
+          //   %z = ashr %y, m
+          // where n != m
+          ConstantInt *LCI = dyn_cast<ConstantInt>(LShOp1);
+          if (!LCI || LCI->getValue().uge(BitWidth))
+            break;
+
+          uint64_t LShAmt = LCI->getZExtValue();
+          // When n < m, we cannot use sext(udiv(trunc(x), 2^(m-n))) since
+          // udiv doesn't preserve sign.
+          if (LShAmt < AShrAmt)
+            break;
+
+          // When n > m, use sext(mul(trunc(x), 2^(n-m))) as the SCEV
+          // expression. We already checked that LShAmt < BitWidth, so
+          // the multiplier, 1 << (LShAmt - AShrAmt), fits into TruncTy as
+          // LShAmt - AShrAmt < TruncToWidth. Build it at TruncToWidth bits
+          // so that types wider than 64 bits are handled as well.
+          APInt Mul = APInt::getOneBitSet(TruncToWidth, LShAmt - AShrAmt);
+          return getSignExtendExpr(
+              getMulExpr(getTruncateExpr(LShOp0SCEV, TruncTy),
+                         getConstant(Mul)),
+              SExtTy);
+        }
+      }
       break;
     }
   }
Index: test/Analysis/ScalarEvolution/sext-mul.ll
===================================================================
--- /dev/null
+++ test/Analysis/ScalarEvolution/sext-mul.ll
@@ -0,0 +1,46 @@
+; RUN: opt < %s -analyze -scalar-evolution | FileCheck %s
+
+; CHECK: %tmp10 = ashr exact i64 %tmp9, 32
+; CHECK: --> (sext i32 {0,+,2}<%bb7> to i64) U: [-2147483648,2147483648) S: [-2147483648,2147483647) Exits: (sext i32 (-2 + (2 * %arg2)) to i64) LoopDispositions: { %bb7: Computable }
+; CHECK: %tmp11 = getelementptr inbounds i32, i32* %arg, i64 %tmp10
+; CHECK: --> ((4 * (sext i32 {0,+,2}<%bb7> to i64)) + %arg) U: full-set S: full-set Exits: ((4 * (sext i32 (-2 + (2 * %arg2)) to i64)) + %arg) LoopDispositions: { %bb7: Computable }
+; CHECK: %tmp14 = or i64 %tmp10, 1
+; CHECK: --> (1 + (sext i32 {0,+,2}<%bb7> to i64)) U: [-2147483647,2147483649) S: [-2147483647,2147483648) Exits: (1 + (sext i32 (-2 + (2 * %arg2)) to i64)) LoopDispositions: { %bb7: Computable }
+; CHECK: %tmp15 = getelementptr inbounds i32, i32* %arg, i64 %tmp14
+; CHECK: --> (4 + (4 * (sext i32 {0,+,2}<%bb7> to i64)) + %arg) U: full-set S: full-set Exits: (4 + (4 * (sext i32 (-2 + (2 * %arg2)) to i64)) + %arg) LoopDispositions: { %bb7: Computable }
+; CHECK:Loop %bb7: backedge-taken count is (-1 + (zext i32 %arg2 to i64))
+; CHECK:Loop %bb7: max backedge-taken count is -1
+; CHECK:Loop %bb7: Predicated backedge-taken count is (-1 + (zext i32 %arg2 to i64))
+
+define void @foo(i32* nocapture %arg, i32 %arg1, i32 %arg2) {
+bb:
+  %tmp = icmp sgt i32 %arg2, 0
+  br i1 %tmp, label %bb3, label %bb6
+
+bb3:                                              ; preds = %bb
+  %tmp4 = zext i32 %arg2 to i64
+  br label %bb7
+
+bb5:                                              ; preds = %bb7
+  br label %bb6
+
+bb6:                                              ; preds = %bb5, %bb
+  ret void
+
+bb7:                                              ; preds = %bb7, %bb3
+  %tmp8 = phi i64 [ %tmp18, %bb7 ], [ 0, %bb3 ]
+  %tmp9 = shl i64 %tmp8, 33
+  %tmp10 = ashr exact i64 %tmp9, 32
+  %tmp11 = getelementptr inbounds i32, i32* %arg, i64 %tmp10
+  %tmp12 = load i32, i32* %tmp11, align 4
+  %tmp13 = sub nsw i32 %tmp12, %arg1
+  store i32 %tmp13, i32* %tmp11, align 4
+  %tmp14 = or i64 %tmp10, 1
+  %tmp15 = getelementptr inbounds i32, i32* %arg, i64 %tmp14
+  %tmp16 = load i32, i32* %tmp15, align 4
+  %tmp17 = mul nsw i32 %tmp16, %arg1
+  store i32 %tmp17, i32* %tmp15, align 4
+  %tmp18 = add nuw nsw i64 %tmp8, 1
+  %tmp19 = icmp eq i64 %tmp18, %tmp4
+  br i1 %tmp19, label %bb5, label %bb7
+}