diff --git a/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
--- a/llvm/lib/Analysis/BasicAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
@@ -360,8 +360,10 @@
   }
 
   LinearExpression mul(const APInt &Other, bool MulIsNSW) const {
-    return LinearExpression(Val, Scale * Other, Offset * Other,
-                            IsNSW && (Other.isOne() || MulIsNSW));
+    // The check for zero offset is necessary, because generally
+    // (X +nsw Y) *nsw Z does not imply (X *nsw Z) +nsw (Y *nsw Z).
+    bool NSW = IsNSW && (Other.isOne() || (MulIsNSW && Offset.isZero()));
+    return LinearExpression(Val, Scale * Other, Offset * Other, NSW);
   }
 };
 }
@@ -1249,12 +1251,14 @@
     CR = CR.intersectWith(
         ConstantRange::fromKnownBits(Known, /* Signed */ true),
         ConstantRange::Signed);
+    CR = Index.Val.evaluateWith(CR).sextOrTrunc(OffsetRange.getBitWidth());
 
     assert(OffsetRange.getBitWidth() == Scale.getBitWidth() &&
            "Bit widths are normalized to MaxPointerSize");
-    OffsetRange = OffsetRange.add(
-        Index.Val.evaluateWith(CR).sextOrTrunc(OffsetRange.getBitWidth())
-            .smul_fast(ConstantRange(Scale)));
+    if (Index.IsNSW)
+      OffsetRange = OffsetRange.add(CR.smul_sat(ConstantRange(Scale)));
+    else
+      OffsetRange = OffsetRange.add(CR.smul_fast(ConstantRange(Scale)));
   }
 
   // We now have accesses at two offsets from the same base:
diff --git a/llvm/test/Analysis/BasicAA/assume-index-positive.ll b/llvm/test/Analysis/BasicAA/assume-index-positive.ll
--- a/llvm/test/Analysis/BasicAA/assume-index-positive.ll
+++ b/llvm/test/Analysis/BasicAA/assume-index-positive.ll
@@ -145,12 +145,12 @@
   ret void
 }
 
-; TODO: Unlike the previous case, %ptr.neg and %ptr.shl can't alias, because
+; Unlike the previous case, %ptr.neg and %ptr.shl can't alias, because
 ; shl nsw of non-negative is non-negative.
 define void @shl_nsw_of_non_negative(i8* %ptr, i64 %a) {
 ; CHECK-LABEL: Function: shl_nsw_of_non_negative
 ; CHECK: NoAlias: i8* %ptr.a, i8* %ptr.neg
-; CHECK: MayAlias: i8* %ptr.neg, i8* %ptr.shl
+; CHECK: NoAlias: i8* %ptr.neg, i8* %ptr.shl
   %a.cmp = icmp sge i64 %a, 0
   call void @llvm.assume(i1 %a.cmp)
   %ptr.neg = getelementptr i8, i8* %ptr, i64 -2
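To make the comment added in the first hunk concrete, here is a small standalone C++ sketch. It is not part of the patch and uses plain 8-bit integers rather than `APInt`; it shows a case where `(X + Y) * Z` stays in range while the distributed partial product `X * Z` does not, which is why `LinearExpression::mul` only keeps `IsNSW` when the accumulated `Offset` is zero (or the multiplier is one).

```cpp
// Standalone sketch (not part of the patch): why (X +nsw Y) *nsw Z does not
// imply (X *nsw Z) +nsw (Y *nsw Z). Plain 8-bit values stand in for APInt.
#include <cstdint>
#include <cstdio>

int main() {
  // Signed 8-bit range is [-128, 127].
  int8_t X = 100, Y = -50, Z = 2;

  // (X + Y) * Z: 50 * 2 = 100 -- neither the add nor the mul wraps,
  // so both operations could legitimately carry the nsw flag.
  int whole = (X + Y) * Z;

  // Distributing gives X*Z + Y*Z: the partial product 100 * 2 = 200 is
  // outside [-128, 127], so X *nsw Z would wrap in 8 bits.
  int partial = X * Z;

  std::printf("(X + Y) * Z = %d (fits in i8)\n", whole);
  std::printf("X * Z       = %d (does not fit in i8)\n", partial);
  return 0;
}
```

The second hunk leans on the same no-wrap reasoning from the other direction: when the index is known `nsw`, the scaled offset range can be computed with the saturating `smul_sat` instead of `smul_fast`, since any index value that would make the multiplication wrap is poison and need not be covered by the range.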