Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp
+++ lib/Analysis/ScalarEvolution.cpp
@@ -5250,20 +5250,22 @@
       break;
 
     case Instruction::AShr:
-      // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
-      if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS))
-        if (Operator *L = dyn_cast<Operator>(BO->LHS))
-          if (L->getOpcode() == Instruction::Shl &&
-              L->getOperand(1) == BO->RHS) {
-            uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType());
-
-            // If the shift count is not less than the bitwidth, the result of
-            // the shift is undefined. Don't try to analyze it, because the
-            // resolution chosen here may differ from the resolution chosen in
-            // other parts of the compiler.
-            if (CI->getValue().uge(BitWidth))
-              break;
+      ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
+      Operator *L = dyn_cast<Operator>(BO->LHS);
+      if (CI && L) {
+        if (L->getOpcode() == Instruction::Shl) {
+          uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType());
+          // If the shift count is not less than the bitwidth, the result of
+          // the shift is undefined. Don't try to analyze it, because the
+          // resolution chosen here may differ from the resolution chosen in
+          // other parts of the compiler.
+          if (CI->getValue().uge(BitWidth))
+            break;
+          Value *LOp1 = L->getOperand(1);
+
+          // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
+          if (LOp1 == BO->RHS) {
 
             uint64_t Amt = BitWidth - CI->getZExtValue();
             if (Amt == BitWidth)
               return getSCEV(L->getOperand(0)); // shift by zero --> noop
@@ -5271,7 +5273,32 @@
                 getTruncateExpr(getSCEV(L->getOperand(0)),
                                 IntegerType::get(getContext(), Amt)),
                 BO->LHS->getType());
+          } else if (ConstantInt *LCI = dyn_cast<ConstantInt>(LOp1)) {
+            // %y = shl %x, n
+            // %z = ashr %y, m
+            // where n != m
+            if (LCI->getValue().uge(BitWidth))
+              break;
+
+            uint64_t AshrAmt = CI->getZExtValue();
+            uint64_t LshAmt = LCI->getZExtValue();
+            if (AshrAmt > LshAmt)
+              // n < m
+              // TODO: udiv(x, 2^(m-n))
+              break;
+            // n > m
+            // Use mul(x, 2^(n-m)) as the SCEV expression. The multiplier is
+            // built as an APInt of the operand width: a plain `1 << k` is a
+            // 32-bit int shift and is undefined for k >= 31, while k can be
+            // as large as BitWidth - 1 here.
+            // NOTE(review): ashr refills the top m bits with the sign bit, so
+            // this matches the shift pair only when x * 2^n does not overflow
+            // (signed); sext(mul(trunc(x), 2^(n-m))) would be exact -- confirm.
+            APInt Mul = APInt::getOneBitSet(BitWidth, LshAmt - AshrAmt);
+            return getMulExpr(getSCEV(L->getOperand(0)), getConstant(Mul));
           }
+        }
+      }
       break;
     }
   }
Index: test/Analysis/ScalarEvolution/sext-mul.ll
===================================================================
--- /dev/null
+++ test/Analysis/ScalarEvolution/sext-mul.ll
@@ -0,0 +1,42 @@
+; RUN: opt < %s -analyze -scalar-evolution | FileCheck %s
+
+; CHECK: %12 = ashr exact i64 %11, 32
+; CHECK: --> {0,+,2}<%9> U: [0,-1) S: [-9223372036854775808,9223372036854775807) Exits: (-2 + (2 * (zext i32 %2 to i64))) LoopDispositions: { %9: Computable }
+; CHECK: %13 = getelementptr inbounds i32, i32* %0, i64 %12
+; CHECK: --> {%0,+,8}<%9> U: full-set S: full-set Exits: (-8 + (8 * (zext i32 %2 to i64)) + %0) LoopDispositions: { %9: Computable }
+; CHECK: %16 = or i64 %12, 1
+; CHECK: --> {1,+,2}<%9> U: full-set S: full-set Exits: (-1 + (2 * (zext i32 %2 to i64))) LoopDispositions: { %9: Computable }
+; CHECK: %17 = getelementptr inbounds i32, i32* %0, i64 %16
+; CHECK: --> {(4 + %0),+,8}<%9> U: full-set S: full-set Exits: (-4 + (8 * (zext i32 %2 to i64)) + %0) LoopDispositions: { %9: Computable }
+
+define void @foo(i32* nocapture, i32, i32) {
+  %4 = icmp sgt i32 %2, 0
+  br i1 %4, label %5, label %8
+
+;