NOTE(review): this preamble was added in review. Free-form text before the first
"diff --git" header is skipped by git apply / git am / patch, so it does not
change what the patch applies.

Purpose: ScalarEvolution handles a shl by a constant amount SA by modelling it as
a multiply by APInt::getOneBitSet(BitWidth, SA), i.e. by 1 << SA. This patch
widens which no-wrap flags from getNoWrapFlagsFromUB(BO->Op) are carried onto
that multiply.

Shift amounts SA >= BitWidth already bail out earlier (the uge(BitWidth) check).
Before this patch: no flags at all were kept when SA == BitWidth - 1.
After this patch:
  nuw is always kept;
  nsw is kept when SA < BitWidth - 1, or when nuw is also present.

Why nsw alone cannot be kept at SA == BitWidth - 1 (i32 example):
  shl nsw i32 -1, 31 is not poison, because every shifted-out bit agrees with
  the result sign bit, and it yields INT_MIN. The corresponding
  mul nsw i32 -1, INT_MIN overflows the signed range.
Why nuw together with nsw is safe at SA == BitWidth - 1:
  shl nuw i32 X, 31 requires X in {0, 1}. Adding nsw also excludes X == 1,
  since the result sign bit is 1 but the shifted-out bits are 0. That leaves
  X == 0, for which mul nuw nsw X, 1 << 31 cannot wrap.

NOTE(review): the line breaks of this patch appear to have been lost; the hunks
are collapsed onto one or two lines. It cannot be applied as-is and must be
regenerated from the original commit with exact context-line indentation.
NOTE(review): in the flags-from-poison.ll hunk, the removed and added CHECK lines
for %index32 are textually identical. This suggests the <...> flag annotations
were stripped in transit; restore them from the original commit before relying
on this test.
NOTE(review): the %index64 expectation changes from a sext of the AddRec to an
AddRec whose start value is sign-extended. This is presumably because the AddRec
for %index32 is now known not to wrap (nsw); confirm against the restored CHECK
lines.
NOTE(review): the (SCEV::NoWrapFlags)(Flags | ...) C-style casts could presumably
be written with the ScalarEvolution::setFlags helper instead; verify it is
available at this site.

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -6453,15 +6453,19 @@ if (SA->getValue().uge(BitWidth)) break; - // It is currently not resolved how to interpret NSW for left - // shift by BitWidth - 1, so we avoid applying flags in that - // case. Remove this check (or this comment) once the situation - // is resolved. See - // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html - // and http://reviews.llvm.org/D8890 . + // We can safely preserve the nuw flag in all cases. It's also safe to + // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation + // requires special handling. It can be preserved as long as we're not + // left shifting by bitwidth - 1. auto Flags = SCEV::FlagAnyWrap; - if (BO->Op && SA->getValue().ult(BitWidth - 1)) - Flags = getNoWrapFlagsFromUB(BO->Op); + if (BO->Op) { + auto MulFlags = getNoWrapFlagsFromUB(BO->Op); + if ((MulFlags & SCEV::FlagNSW) && + ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1))) + Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNSW); + if (MulFlags & SCEV::FlagNUW) + Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNUW); + } Constant *X = ConstantInt::get( getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue())); diff --git a/llvm/test/Analysis/ScalarEvolution/flags-from-poison.ll b/llvm/test/Analysis/ScalarEvolution/flags-from-poison.ll --- a/llvm/test/Analysis/ScalarEvolution/flags-from-poison.ll +++ b/llvm/test/Analysis/ScalarEvolution/flags-from-poison.ll @@ -559,11 +559,11 @@ %i = phi i32 [ %nexti, %loop ], [ %start, %entry ] ; CHECK: %index32 = -; CHECK: --> {(-2147483648 * %start),+,-2147483648}<%loop> +; CHECK: --> {(-2147483648 * %start),+,-2147483648}<%loop> %index32 = shl nuw nsw i32 %i, 31 ; CHECK: %index64 = -; CHECK: --> (sext i32 {(-2147483648 * %start),+,-2147483648}<%loop> +; CHECK: --> {(sext i32 (-2147483648 * %start) to 
i64),+,-2147483648}<%loop> %index64 = sext i32 %index32 to i64 %ptr = getelementptr inbounds float, float* %input, i64 %index64