diff --git a/llvm/include/llvm/Analysis/BasicAliasAnalysis.h b/llvm/include/llvm/Analysis/BasicAliasAnalysis.h
--- a/llvm/include/llvm/Analysis/BasicAliasAnalysis.h
+++ b/llvm/include/llvm/Analysis/BasicAliasAnalysis.h
@@ -116,6 +116,9 @@
     // Context instruction to use when querying information about this index.
     const Instruction *CxtI;
 
+    /// True if all operations in this expression are NSW.
+    bool IsNSW;
+
     void dump() const {
       print(dbgs());
       dbgs() << "\n";
diff --git a/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
--- a/llvm/lib/Analysis/BasicAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
@@ -284,11 +284,14 @@
   APInt Scale;
   APInt Offset;
 
+  /// True if all operations in this expression are NSW.
+  bool IsNSW;
+
   LinearExpression(const ExtendedValue &Val, const APInt &Scale,
-                   const APInt &Offset)
-      : Val(Val), Scale(Scale), Offset(Offset) {}
+                   const APInt &Offset, bool IsNSW)
+      : Val(Val), Scale(Scale), Offset(Offset), IsNSW(IsNSW) {}
 
-  LinearExpression(const ExtendedValue &Val) : Val(Val) {
+  LinearExpression(const ExtendedValue &Val) : Val(Val), IsNSW(true) {
     unsigned BitWidth = Val.getBitWidth();
     Scale = APInt(BitWidth, 1);
     Offset = APInt(BitWidth, 0);
@@ -307,7 +310,7 @@
 
   if (const ConstantInt *Const = dyn_cast<ConstantInt>(Val.V))
     return LinearExpression(Val, APInt(Val.getBitWidth(), 0),
-                            Val.evaluateWith(Const->getValue()));
+                            Val.evaluateWith(Const->getValue()), true);
 
   if (const BinaryOperator *BOp = dyn_cast<BinaryOperator>(Val.V)) {
     if (ConstantInt *RHSC = dyn_cast<ConstantInt>(BOp->getOperand(1))) {
@@ -322,6 +325,7 @@
       if (!Val.canDistributeOver(NUW, NSW))
         return Val;
 
+      LinearExpression E(Val);
       switch (BOp->getOpcode()) {
       default:
         // We don't understand this instruction, so we can't decompose it any
@@ -336,23 +340,26 @@
 
         LLVM_FALLTHROUGH;
       case Instruction::Add: {
-        LinearExpression E = GetLinearExpression(
-            Val.withValue(BOp->getOperand(0)), DL, Depth + 1, AC, DT);
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+                                Depth + 1, AC, DT);
         E.Offset += RHS;
-        return E;
+        E.IsNSW &= NSW;
+        break;
       }
       case Instruction::Sub: {
-        LinearExpression E = GetLinearExpression(
-            Val.withValue(BOp->getOperand(0)), DL, Depth + 1, AC, DT);
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+                                Depth + 1, AC, DT);
         E.Offset -= RHS;
-        return E;
+        E.IsNSW &= NSW;
+        break;
       }
       case Instruction::Mul: {
-        LinearExpression E = GetLinearExpression(
-            Val.withValue(BOp->getOperand(0)), DL, Depth + 1, AC, DT);
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+                                Depth + 1, AC, DT);
         E.Offset *= RHS;
         E.Scale *= RHS;
-        return E;
+        E.IsNSW &= NSW;
+        break;
       }
       case Instruction::Shl:
         // We're trying to linearize an expression of the kind:
@@ -363,12 +370,14 @@
         if (RHS.getLimitedValue() > Val.getBitWidth())
           return Val;
 
-        LinearExpression E = GetLinearExpression(
-            Val.withValue(BOp->getOperand(0)), DL, Depth + 1, AC, DT);
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+                                Depth + 1, AC, DT);
         E.Offset <<= RHS.getLimitedValue();
         E.Scale <<= RHS.getLimitedValue();
-        return E;
+        E.IsNSW &= NSW;
+        break;
       }
+      return E;
     }
   }
 
@@ -568,6 +577,9 @@
             Decomposed.VarIndices[i].ZExtBits == LE.Val.ZExtBits &&
             Decomposed.VarIndices[i].SExtBits == LE.Val.SExtBits) {
           Scale += Decomposed.VarIndices[i].Scale;
+          // The existing entry may not be NSW and the scale addition itself
+          // may wrap, so NSW cannot be guaranteed for the merged entry.
+          LE.IsNSW = false;
           Decomposed.VarIndices.erase(Decomposed.VarIndices.begin() + i);
           break;
         }
@@ -578,8 +590,8 @@
       Scale = adjustToPointerSize(Scale, PointerSize);
 
       if (!!Scale) {
-        VariableGEPIndex Entry = {LE.Val.V, LE.Val.ZExtBits, LE.Val.SExtBits,
-                                  Scale, CxtI};
+        VariableGEPIndex Entry = {
+            LE.Val.V, LE.Val.ZExtBits, LE.Val.SExtBits, Scale, CxtI, LE.IsNSW};
         Decomposed.VarIndices.push_back(Entry);
       }
     }
@@ -1138,11 +1150,20 @@
   bool AllNonNegative = DecompGEP1.Offset.isNonNegative();
   bool AllNonPositive = DecompGEP1.Offset.isNonPositive();
   for (unsigned i = 0, e = DecompGEP1.VarIndices.size(); i != e; ++i) {
     const APInt &Scale = DecompGEP1.VarIndices[i].Scale;
+    // If the index computation may wrap, Scale * V is only known modulo
+    // 2^BitWidth, so only the power-of-2 factor of Scale is guaranteed to
+    // divide it. Use a separate value for the GCD and keep Scale (and in
+    // particular its sign) intact for the sign reasoning below.
+    APInt ScaleForGCD = Scale;
+    if (!DecompGEP1.VarIndices[i].IsNSW)
+      ScaleForGCD = APInt::getOneBitSet(Scale.getBitWidth(),
+                                        Scale.countTrailingZeros());
+
     if (i == 0)
-      GCD = Scale.abs();
+      GCD = ScaleForGCD.abs();
     else
-      GCD = APIntOps::GreatestCommonDivisor(GCD, Scale.abs());
+      GCD = APIntOps::GreatestCommonDivisor(GCD, ScaleForGCD.abs());
 
     if (AllNonNegative || AllNonPositive) {
       // If the Value could change between cycles, then any reasoning about
@@ -1701,9 +1722,10 @@
 
       // If we found it, subtract off Scale V's from the entry in Dest.  If it
       // goes to zero, remove the entry.
-      if (Dest[j].Scale != Scale)
+      if (Dest[j].Scale != Scale) {
         Dest[j].Scale -= Scale;
-      else
+        Dest[j].IsNSW = false;
+      } else
         Dest.erase(Dest.begin() + j);
       Scale = 0;
       break;
@@ -1711,7 +1733,8 @@
 
     // If we didn't consume this entry, add it to the end of the Dest list.
     if (!!Scale) {
-      VariableGEPIndex Entry = {V, ZExtBits, SExtBits, -Scale, Src[i].CxtI};
+      VariableGEPIndex Entry = {V, ZExtBits, SExtBits,
+                                -Scale, Src[i].CxtI, Src[i].IsNSW};
       Dest.push_back(Entry);
     }
   }
diff --git a/llvm/test/Analysis/BasicAA/gep-modulo.ll b/llvm/test/Analysis/BasicAA/gep-modulo.ll
--- a/llvm/test/Analysis/BasicAA/gep-modulo.ll
+++ b/llvm/test/Analysis/BasicAA/gep-modulo.ll
@@ -70,7 +70,7 @@
 ; CHECK-LABEL: Function: may_overflow_mul_sub_i64: 3 pointers, 0 call sites
 ; CHECK-NEXT:  MayAlias: [16 x i8]* %ptr, i8* %gep.idx
 ; CHECK-NEXT:  PartialAlias (off 3): [16 x i8]* %ptr, i8* %gep.3
-; CHECK-NEXT:  NoAlias: i8* %gep.3, i8* %gep.idx
+; CHECK-NEXT:  MayAlias: i8* %gep.3, i8* %gep.idx
 ;
   %mul = mul i64 %idx, 5
   %sub = sub i64 %mul, 1
@@ -115,7 +115,7 @@
 ; CHECK-LABEL: Function: only_nuw_mul_sub_i64: 3 pointers, 0 call sites
 ; CHECK-NEXT:  MayAlias: [16 x i8]* %ptr, i8* %gep.idx
 ; CHECK-NEXT:  PartialAlias (off 3): [16 x i8]* %ptr, i8* %gep.3
-; CHECK-NEXT:  NoAlias: i8* %gep.3, i8* %gep.idx
+; CHECK-NEXT:  MayAlias: i8* %gep.3, i8* %gep.idx
 ;
   %mul = mul nuw i64 %idx, 5
   %sub = sub nuw i64 %mul, 1
@@ -126,6 +126,24 @@
   ret void
 }
 
+; The mul may overflow, so only its power-of-2 factor may be used for the GCD,
+; but the (negative) sign of the scale must still be respected: %gep.m4 and
+; %gep.mul point to the same address when %idx == 1.
+define void @may_overflow_mul_neg_scale_zext([16 x i8]* %ptr, i32 %idx) {
+; CHECK-LABEL: Function: may_overflow_mul_neg_scale_zext: 3 pointers, 0 call sites
+; CHECK-NEXT:  NoAlias: [16 x i8]* %ptr, i8* %gep.m4
+; CHECK-NEXT:  MayAlias: [16 x i8]* %ptr, i8* %gep.mul
+; CHECK-NEXT:  MayAlias: i8* %gep.m4, i8* %gep.mul
+;
+  %idx.ext = zext i32 %idx to i64
+  %gep.m4 = getelementptr [16 x i8], [16 x i8]* %ptr, i32 0, i64 -4
+  %mul = mul i64 %idx.ext, -4
+  %gep.mul = getelementptr [16 x i8], [16 x i8]* %ptr, i32 0, i64 %mul
+  ret void
+}
+
+; Even though the mul and sub may overflow, %gep.idx and %gep.3 cannot alias
+; because we multiply by a power-of-2.
 define void @may_overflow_mul_pow2_sub_i64([16 x i8]* %ptr, i64 %idx) {
 ; CHECK-LABEL: Function: may_overflow_mul_pow2_sub_i64: 3 pointers, 0 call sites
 ; CHECK-NEXT:  MayAlias: [16 x i8]* %ptr, i8* %gep.idx
@@ -259,7 +277,7 @@
 ; CHECK-LABEL: Function: may_overflow_pointer_diff: 3 pointers, 0 call sites
 ; CHECK-NEXT:  MayAlias: [16 x i8]* %ptr, i8* %gep.mul.1
 ; CHECK-NEXT:  MayAlias: [16 x i8]* %ptr, i8* %gep.sub.2
-; CHECK-NEXT:  NoAlias: i8* %gep.mul.1, i8* %gep.sub.2
+; CHECK-NEXT:  MayAlias: i8* %gep.mul.1, i8* %gep.sub.2
 ;
   %mul.1 = mul i64 %idx, 6148914691236517207
   %gep.mul.1 = getelementptr [16 x i8], [16 x i8]* %ptr, i32 0, i64 %mul.1