diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -4293,7 +4293,9 @@
 } // end anonymous namespace
 
 /// Try to map \p V into a BinaryOp, and return \c None on failure.
-static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
+static Optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
+                                        AssumptionCache &AC,
+                                        DominatorTree &DT) {
   auto *Op = dyn_cast<Operator>(V);
   if (!Op)
     return None;
@@ -4340,6 +4342,21 @@
     }
     return BinaryOp(Op);
 
+  case Instruction::SRem: {
+    // If the sign bits of both operands are zero (i.e. we can prove they are
+    // unsigned inputs), this is just an urem.
+    auto IsNonNegative = [&](Value *V) {
+      APInt Mask(APInt::getSignMask(Op->getType()->getScalarSizeInBits()));
+      return MaskedValueIsZero(V, Mask, DL, /*Depth=*/0, &AC,
+                               dyn_cast<Instruction>(Op), &DT);
+    };
+    if (IsNonNegative(Op->getOperand(1)) && IsNonNegative(Op->getOperand(0))) {
+      // X srem Y -> X urem Y, iff X and Y don't have sign bit set
+      return BinaryOp(Instruction::URem, Op->getOperand(0), Op->getOperand(1));
+    }
+    break;
+  }
+
   case Instruction::ExtractValue: {
     auto *EVI = cast<ExtractValueInst>(Op);
     if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
@@ -4767,7 +4784,7 @@
   assert(L && L->getHeader() == PN->getParent());
   assert(BEValueV && StartValueV);
 
-  auto BO = MatchBinaryOp(BEValueV, DT);
+  auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT);
   if (!BO)
     return nullptr;
 
@@ -4880,7 +4897,7 @@
         cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
       SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
 
-      if (auto BO = MatchBinaryOp(BEValueV, DT)) {
+      if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT)) {
        if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
          if (BO->IsNUW)
            Flags = setFlags(Flags, SCEV::FlagNUW);
@@ -5984,7 +6001,7 @@
     return getUnknown(V);
 
   Operator *U = cast<Operator>(V);
-  if (auto BO = MatchBinaryOp(U, DT)) {
+  if (auto BO = MatchBinaryOp(U, getDataLayout(), AC, DT)) {
     switch (BO->Opcode) {
     case Instruction::Add: {
       // The simple thing to do would be to just call getSCEV on both operands
@@ -6025,7 +6042,7 @@
         else
           AddOps.push_back(getSCEV(BO->RHS));
 
-        auto NewBO = MatchBinaryOp(BO->LHS, DT);
+        auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT);
         if (!NewBO || (NewBO->Opcode != Instruction::Add &&
                        NewBO->Opcode != Instruction::Sub)) {
           AddOps.push_back(getSCEV(BO->LHS));
@@ -6055,7 +6072,7 @@
         }
 
         MulOps.push_back(getSCEV(BO->RHS));
-        auto NewBO = MatchBinaryOp(BO->LHS, DT);
+        auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT);
         if (!NewBO || NewBO->Opcode != Instruction::Mul) {
           MulOps.push_back(getSCEV(BO->LHS));
           break;
@@ -6280,7 +6297,7 @@
       return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
 
   case Instruction::SExt:
-    if (auto BO = MatchBinaryOp(U->getOperand(0), DT)) {
+    if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT)) {
       // The NSW flag of a subtract does not always survive the conversion to
       // A + (-1)*B. By pushing sign extension onto its operands we are much
       // more likely to preserve NSW and allow later AddRec optimisations.
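For illustration, a minimal hypothetical IR snippet (not taken from this patch): once MaskedValueIsZero proves the sign bits of both srem operands are zero, the new SRem case lets MatchBinaryOp report the instruction as a urem, which SCEV already knows how to model.

    define i32 @srem_as_urem(i32 %x) {
    entry:
      ; The mask clears the sign bit, so %masked is provably non-negative.
      %masked = and i32 %x, 2147483647
      ; The constant divisor 8 is non-negative too, so "X srem Y" here is
      ; equivalent to "X urem Y" and SCEV no longer treats %rem as opaque.
      %rem = srem i32 %masked, 8
      ret i32 %rem
    }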
diff --git a/llvm/test/Analysis/ScalarEvolution/srem.ll b/llvm/test/Analysis/ScalarEvolution/srem.ll
--- a/llvm/test/Analysis/ScalarEvolution/srem.ll
+++ b/llvm/test/Analysis/ScalarEvolution/srem.ll
@@ -14,11 +14,11 @@
 ; CHECK-NEXT:    %i.0 = phi i32 [ 0, %entry ], [ %inc, %for.body ]
 ; CHECK-NEXT:    --> {0,+,1}<%for.cond> U: [0,-2147483648) S: [0,-2147483648) Exits: %width LoopDispositions: { %for.cond: Computable }
 ; CHECK-NEXT:    %rem = srem i32 %i.0, 2
-; CHECK-NEXT:    --> %rem U: [0,2) S: [-2,2) Exits: <<Unknown>> LoopDispositions: { %for.cond: Variant }
+; CHECK-NEXT:    --> (zext i1 {false,+,true}<%for.cond> to i32) U: [0,2) S: [0,2) Exits: (zext i1 (trunc i32 %width to i1) to i32) LoopDispositions: { %for.cond: Computable }
 ; CHECK-NEXT:    %idxprom = sext i32 %rem to i64
-; CHECK-NEXT:    --> (sext i32 %rem to i64) U: [0,2) S: [-2,2) Exits: <<Unknown>> LoopDispositions: { %for.cond: Variant }
+; CHECK-NEXT:    --> (zext i1 {false,+,true}<%for.cond> to i64) U: [0,2) S: [0,2) Exits: (zext i1 (trunc i32 %width to i1) to i64) LoopDispositions: { %for.cond: Computable }
 ; CHECK-NEXT:    %arrayidx = getelementptr inbounds [2 x i32], [2 x i32]* %storage, i64 0, i64 %idxprom
-; CHECK-NEXT:    --> ((4 * (sext i32 %rem to i64)) + %storage) U: [0,-3) S: [-9223372036854775808,9223372036854775805) Exits: <<Unknown>> LoopDispositions: { %for.cond: Variant }
+; CHECK-NEXT:    --> ((4 * (zext i1 {false,+,true}<%for.cond> to i64)) + %storage) U: [0,-3) S: [-9223372036854775808,9223372036854775805) Exits: ((4 * (zext i1 (trunc i32 %width to i1) to i64)) + %storage) LoopDispositions: { %for.cond: Computable }
 ; CHECK-NEXT:    %1 = load i32, i32* %arrayidx, align 4
 ; CHECK-NEXT:    --> %1 U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.cond: Variant }
 ; CHECK-NEXT:    %call = call i32 @_Z3adji(i32 %1)
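The updated CHECK lines can be reproduced by running the ScalarEvolution printer over the test (the authoritative invocation is the RUN line in srem.ll itself), e.g. with the legacy pass manager:

    opt -analyze -scalar-evolution llvm/test/Analysis/ScalarEvolution/srem.ll

or with the new pass manager:

    opt -disable-output -passes='print<scalar-evolution>' llvm/test/Analysis/ScalarEvolution/srem.ll

With the patch applied, %rem goes from an opaque loop-variant unknown to the computable add-recurrence (zext i1 {false,+,true}<%for.cond> to i32) shown above.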