Index: include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- include/llvm/Analysis/ScalarEvolution.h
+++ include/llvm/Analysis/ScalarEvolution.h
@@ -712,7 +712,8 @@
 
   /// getNegativeSCEV - Return the SCEV object corresponding to -V.
   ///
-  const SCEV *getNegativeSCEV(const SCEV *V);
+  const SCEV *getNegativeSCEV(const SCEV *V,
+                              SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap);
 
   /// getNotSCEV - Return the SCEV object corresponding to ~V.
   ///
Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp
+++ lib/Analysis/ScalarEvolution.cpp
@@ -3339,15 +3339,16 @@
 
 /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
 ///
-const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V) {
+const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
+                                             SCEV::NoWrapFlags Flags) {
   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
     return getConstant(
                cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
 
   Type *Ty = V->getType();
   Ty = getEffectiveSCEVType(Ty);
-  return getMulExpr(V,
-                  getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))));
+  return getMulExpr(
+      V, getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))), Flags);
 }
 
 /// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
@@ -3366,15 +3367,40 @@
 /// getMinusSCEV - Return LHS-RHS. Minus is represented in SCEV as A+B*-1.
 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
                                           SCEV::NoWrapFlags Flags) {
-  assert(!maskFlags(Flags, SCEV::FlagNUW) && "subtraction does not have NUW");
-
   // Fast path: X - X --> 0.
   if (LHS == RHS)
     return getConstant(LHS->getType(), 0);
 
-  // X - Y --> X + -Y.
-  // X -(nsw || nuw) Y --> X + -Y.
-  return getAddExpr(LHS, getNegativeSCEV(RHS));
+  auto AddFlags = SCEV::FlagAnyWrap;
+  auto NegFlags = SCEV::FlagAnyWrap;
+
+  // We will transform LHS - RHS to LHS + (-RHS). Thus we cannot make
+  // use of NUW, since -RHS will unsigned-wrap for any non-zero value.
+  if (maskFlags(Flags, SCEV::FlagNSW) == SCEV::FlagNSW) {
+    // Let M be the minimum representable signed value. Then -RHS
+    // signed-wraps if and only if RHS is M. That can happen even for
+    // a NSW subtraction because e.g. -M signed-wraps even though -1 - M
+    // does not. So to transfer NSW from LHS - RHS to LHS + (-RHS), we
+    // need to prove that RHS != M.
+    //
+    // If LHS is non-negative and we know that LHS - RHS does not
+    // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
+    // either by proving that RHS > M or that LHS >= 0.
+    if (isKnownNonNegative(LHS) ||
+        !getSignedRange(RHS).getSignedMin().isMinSignedValue()) {
+      AddFlags = SCEV::FlagNSW;
+
+      // We need to handle the situation where the LHS supplies the
+      // recurrence that has the loop that the flags are relative to
+      // while the RHS has no recurrence. If RHS has no recurrence and
+      // we attach NSW to the SCEV for -RHS, then we are stating that
+      // -RHS never wraps anywhere, but we only know that -RHS does
+      // not wrap within the relevant loop.
+      if (isa<SCEVAddRecExpr>(RHS)) NegFlags = SCEV::FlagNSW;
+    }
+  }
+
+  return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags);
 }
 
 /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
@@ -4094,7 +4120,8 @@
 }
 
 SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
-  const BinaryOperator *BinOp = cast<BinaryOperator>(V);
+  const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(V);
+  if (!BinOp) return SCEV::FlagAnyWrap;
 
   // Return early if there are no flags to propagate to the SCEV.
   SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
@@ -4185,9 +4212,6 @@
     // because it leads to N-1 getAddExpr calls for N ultimate operands.
     // Instead, gather up all the operands and make a single getAddExpr call.
    // LLVM IR canonical form means we need only traverse the left operands.
-    //
-    // FIXME: Expand this handling of NSW and NUW to other instructions, like
-    // sub and mul.
     SmallVector<const SCEV *, 4> AddOps;
     for (Value *Op = U;; Op = U->getOperand(0)) {
       U = dyn_cast<Operator>(Op);
@@ -4210,45 +4234,56 @@
       // since the flags are only known to apply to this particular
       // addition - they may not apply to other additions that can be
       // formed with operands from AddOps.
-      //
-      // FIXME: Expand this to sub instructions.
-      if (Opcode == Instruction::Add && isa<BinaryOperator>(U)) {
-        SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U);
-        if (Flags != SCEV::FlagAnyWrap) {
-          AddOps.push_back(getAddExpr(getSCEV(U->getOperand(0)),
-                                      getSCEV(U->getOperand(1)), Flags));
-          break;
-        }
+      const SCEV *RHS = getSCEV(U->getOperand(1));
+      SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U);
+      if (Flags != SCEV::FlagAnyWrap) {
+        if (Opcode == Instruction::Sub)
+          AddOps.push_back(getMinusSCEV(getSCEV(U->getOperand(0)), RHS, Flags));
+        else
+          AddOps.push_back(getAddExpr(getSCEV(U->getOperand(0)), RHS, Flags));
+        break;
       }
 
-      const SCEV *Op1 = getSCEV(U->getOperand(1));
       if (Opcode == Instruction::Sub)
-        AddOps.push_back(getNegativeSCEV(Op1));
+        AddOps.push_back(getNegativeSCEV(RHS));
       else
-        AddOps.push_back(Op1);
+        AddOps.push_back(RHS);
     }
 
     return getAddExpr(AddOps);
   }
 
   case Instruction::Mul: {
-    // FIXME: Transfer NSW/NUW as in AddExpr.
     SmallVector<const SCEV *, 4> MulOps;
-    MulOps.push_back(getSCEV(U->getOperand(1)));
-    for (Value *Op = U->getOperand(0);
-         Op->getValueID() == Instruction::Mul + Value::InstructionVal;
-         Op = U->getOperand(0)) {
-      U = cast<Operator>(Op);
+    for (Value *Op = U;; Op = U->getOperand(0)) {
+      U = dyn_cast<Operator>(Op);
+      if (!U || U->getOpcode() != Instruction::Mul) {
+        assert(Op != V && "V should be a mul");
+        MulOps.push_back(getSCEV(Op));
+        break;
+      }
+
+      if (auto *OpSCEV = getExistingSCEV(Op)) {
+        MulOps.push_back(OpSCEV);
+        break;
+      }
+
+      SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U);
+      if (Flags != SCEV::FlagAnyWrap) {
+        MulOps.push_back(getMulExpr(getSCEV(U->getOperand(0)),
+                                    getSCEV(U->getOperand(1)), Flags));
+        break;
+      }
+
       MulOps.push_back(getSCEV(U->getOperand(1)));
     }
-    MulOps.push_back(getSCEV(U->getOperand(0)));
     return getMulExpr(MulOps);
   }
 
   case Instruction::UDiv:
     return getUDivExpr(getSCEV(U->getOperand(0)),
                        getSCEV(U->getOperand(1)));
   case Instruction::Sub:
-    return getMinusSCEV(getSCEV(U->getOperand(0)),
-                        getSCEV(U->getOperand(1)));
+    return getMinusSCEV(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)),
+                        getNoWrapFlagsFromUB(U));
   case Instruction::And:
     // For an expression like x&255 that merely masks off the high bits,
     // use zext(trunc(x)) as the SCEV expression.
@@ -4370,7 +4405,8 @@
       Constant *X = ConstantInt::get(getContext(),
         APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
-      return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
+      return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X),
+                        getNoWrapFlagsFromUB(U));
     }
     break;
Index: test/Analysis/Delinearization/a.ll
===================================================================
--- test/Analysis/Delinearization/a.ll
+++ test/Analysis/Delinearization/a.ll
@@ -10,7 +10,7 @@
 ; AddRec: {{{(28 + (4 * (-4 + (3 * %m)) * %o) + %A),+,(8 * %m * %o)}<%for.i>,+,(12 * %o)}<%for.j>,+,20}<%for.k>
 ; CHECK: Base offset: %A
 ; CHECK: ArrayDecl[UnknownSize][%m][%o] with elements of 4 bytes.
-; CHECK: ArrayRef[{3,+,2}<%for.i>][{-4,+,3}<%for.j>][{7,+,5}<%for.k>] +; CHECK: ArrayRef[{3,+,2}<%for.i>][{-4,+,3}<%for.j>][{7,+,5}<%for.k>] define void @foo(i64 %n, i64 %m, i64 %o, i32* nocapture %A) #0 { entry: Index: test/Analysis/ScalarEvolution/flags-from-poison.ll =================================================================== --- test/Analysis/ScalarEvolution/flags-from-poison.ll +++ test/Analysis/ScalarEvolution/flags-from-poison.ll @@ -356,3 +356,204 @@ exit: ret void } + +; Example where a mul should get the nsw flag, so that a sext can be +; distributed over the mul. +define void @test-mul-nsw(float* %input, i32 %stride, i32 %numIterations) { +; CHECK-LABEL: @test-mul-nsw +entry: + br label %loop +loop: + %i = phi i32 [ %nexti, %loop ], [ 0, %entry ] + +; CHECK: %index32 = +; CHECK: --> {0,+,%stride} + %index32 = mul nsw i32 %i, %stride + +; CHECK: %index64 = +; CHECK: --> {0,+,(sext i32 %stride to i64)} + %index64 = sext i32 %index32 to i64 + + %ptr = getelementptr inbounds float, float* %input, i64 %index64 + %nexti = add nsw i32 %i, 1 + %f = load float, float* %ptr, align 4 + %exitcond = icmp eq i32 %nexti, %numIterations + br i1 %exitcond, label %exit, label %loop +exit: + ret void +} + +; Example where a mul should get the nuw flag. +define void @test-mul-nuw(float* %input, i32 %stride, i32 %numIterations) { +; CHECK-LABEL: @test-mul-nuw +entry: + br label %loop +loop: + %i = phi i32 [ %nexti, %loop ], [ 0, %entry ] + +; CHECK: %index32 = +; CHECK: --> {0,+,%stride} + %index32 = mul nuw i32 %i, %stride + + %ptr = getelementptr inbounds float, float* %input, i32 %index32 + %nexti = add nuw i32 %i, 1 + %f = load float, float* %ptr, align 4 + %exitcond = icmp eq i32 %nexti, %numIterations + br i1 %exitcond, label %exit, label %loop + +exit: + ret void +} + +; Example where a shl should get the nsw flag, so that a sext can be +; distributed over the shl. 
+define void @test-shl-nsw(float* %input, i32 %start, i32 %numIterations) { +; CHECK-LABEL: @test-shl-nsw +entry: + br label %loop +loop: + %i = phi i32 [ %nexti, %loop ], [ %start, %entry ] + +; CHECK: %index32 = +; CHECK: --> {(256 * %start),+,256} + %index32 = shl nsw i32 %i, 8 + +; CHECK: %index64 = +; CHECK: --> {(sext i32 (256 * %start) to i64),+,256} + %index64 = sext i32 %index32 to i64 + + %ptr = getelementptr inbounds float, float* %input, i64 %index64 + %nexti = add nsw i32 %i, 1 + %f = load float, float* %ptr, align 4 + %exitcond = icmp eq i32 %nexti, %numIterations + br i1 %exitcond, label %exit, label %loop +exit: + ret void +} + +; Example where a shl should get the nuw flag. +define void @test-shl-nuw(float* %input, i32 %numIterations) { +; CHECK-LABEL: @test-shl-nuw +entry: + br label %loop +loop: + %i = phi i32 [ %nexti, %loop ], [ 0, %entry ] + +; CHECK: %index32 = +; CHECK: --> {0,+,512} + %index32 = shl nuw i32 %i, 9 + + %ptr = getelementptr inbounds float, float* %input, i32 %index32 + %nexti = add nuw i32 %i, 1 + %f = load float, float* %ptr, align 4 + %exitcond = icmp eq i32 %nexti, %numIterations + br i1 %exitcond, label %exit, label %loop + +exit: + ret void +} + +; Example where a sub should *not* get the nsw flag, because of how +; scalar evolution represents A - B as A + (-B) and -B can wrap even +; in cases where A - B does not. 
+define void @test-sub-no-nsw(float* %input, i32 %start, i32 %sub, i32 %numIterations) { +; CHECK-LABEL: @test-sub-no-nsw +entry: + br label %loop +loop: + %i = phi i32 [ %nexti, %loop ], [ %start, %entry ] + +; CHECK: %index32 = +; CHECK: --> {((-1 * %sub) + %start),+,1} + %index32 = sub nsw i32 %i, %sub + %index64 = sext i32 %index32 to i64 + + %ptr = getelementptr inbounds float, float* %input, i64 %index64 + %nexti = add nsw i32 %i, 1 + %f = load float, float* %ptr, align 4 + %exitcond = icmp eq i32 %nexti, %numIterations + br i1 %exitcond, label %exit, label %loop +exit: + ret void +} + +; Example where a sub should get the nsw flag as the RHS cannot be the +; minimal signed value. +define void @test-sub-nsw(float* %input, i32 %start, i32 %sub, i32 %numIterations) { +; CHECK-LABEL: @test-sub-nsw +entry: + %halfsub = ashr i32 %sub, 1 + br label %loop +loop: + %i = phi i32 [ %nexti, %loop ], [ %start, %entry ] + +; CHECK: %index32 = +; CHECK: --> {((-1 * %halfsub) + %start),+,1} + %index32 = sub nsw i32 %i, %halfsub + %index64 = sext i32 %index32 to i64 + + %ptr = getelementptr inbounds float, float* %input, i64 %index64 + %nexti = add nsw i32 %i, 1 + %f = load float, float* %ptr, align 4 + %exitcond = icmp eq i32 %nexti, %numIterations + br i1 %exitcond, label %exit, label %loop +exit: + ret void +} + +; Example where a sub should get the nsw flag, since the LHS is non-negative, +; which implies that the RHS cannot be the minimal signed value. 
+define void @test-sub-nsw-lhs-non-negative(float* %input, i32 %sub, i32 %numIterations) { +; CHECK-LABEL: @test-sub-nsw-lhs-non-negative +entry: + br label %loop +loop: + %i = phi i32 [ %nexti, %loop ], [ 0, %entry ] + +; CHECK: %index32 = +; CHECK: --> {(-1 * %sub),+,1} + %index32 = sub nsw i32 %i, %sub + +; CHECK: %index64 = +; CHECK: --> {(sext i32 (-1 * %sub) to i64),+,1} + %index64 = sext i32 %index32 to i64 + + %ptr = getelementptr inbounds float, float* %input, i64 %index64 + %nexti = add nsw i32 %i, 1 + %f = load float, float* %ptr, align 4 + %exitcond = icmp eq i32 %nexti, %numIterations + br i1 %exitcond, label %exit, label %loop +exit: + ret void +} + +; Two adds with a sub in the middle and the sub should have nsw. There is +; a special case for sequential adds/subs and this test covers that. We have to +; put the final add first in the program since otherwise the special case +; is not triggered, hence the strange basic block ordering. +define void @test-sub-with-add(float* %input, i32 %offset, i32 %numIterations) { +; CHECK-LABEL: @test-sub-with-add +entry: + br label %loop +loop2: +; CHECK: %seq = +; CHECK: --> {(2 + (-1 * %offset)),+,1} + %seq = add nsw nuw i32 %index32, 1 + %exitcond = icmp eq i32 %nexti, %numIterations + br i1 %exitcond, label %exit, label %loop + +loop: + %i = phi i32 [ %nexti, %loop2 ], [ 0, %entry ] + + %j = add nsw i32 %i, 1 +; CHECK: %index32 = +; CHECK: --> {(1 + (-1 * %offset)),+,1} + %index32 = sub nsw i32 %j, %offset + + %ptr = getelementptr inbounds float, float* %input, i32 %index32 + %nexti = add nsw i32 %i, 1 + store float 1.0, float* %ptr, align 4 + br label %loop2 +exit: + ret void +} Index: test/Transforms/LoopStrengthReduce/sext-ind-var.ll =================================================================== --- test/Transforms/LoopStrengthReduce/sext-ind-var.ll +++ test/Transforms/LoopStrengthReduce/sext-ind-var.ll @@ -8,6 +8,11 @@ ; instruction to the SCEV, preventing distributing sext into the ; corresponding 
addrec. +; Test this pattern: +; +; for (int i = 0; i < numIterations; ++i) +; sum += ptr[i + offset]; +; define float @testadd(float* %input, i32 %offset, i32 %numIterations) { ; CHECK-LABEL: @testadd ; CHECK: sext i32 %offset to i64 @@ -34,3 +39,102 @@ exit: ret float %nextsum } + +; Test this pattern: +; +; for (int i = 0; i < numIterations; ++i) +; sum += ptr[i - offset]; +; +define float @testsub(float* %input, i32 %offset, i32 %numIterations) { +; CHECK-LABEL: @testsub +; CHECK: sub i32 0, %offset +; CHECK: sext i32 +; CHECK: loop: +; CHECK-DAG: phi float* +; CHECK-DAG: phi i32 +; CHECK-NOT: sext + +entry: + br label %loop + +loop: + %i = phi i32 [ %nexti, %loop ], [ 0, %entry ] + %sum = phi float [ %nextsum, %loop ], [ 0.000000e+00, %entry ] + %index32 = sub nuw nsw i32 %i, %offset + %index64 = sext i32 %index32 to i64 + %ptr = getelementptr inbounds float, float* %input, i64 %index64 + %addend = load float, float* %ptr, align 4 + %nextsum = fadd float %sum, %addend + %nexti = add nuw nsw i32 %i, 1 + %exitcond = icmp eq i32 %nexti, %numIterations + br i1 %exitcond, label %exit, label %loop + +exit: + ret float %nextsum +} + +; Test this pattern: +; +; for (int i = 0; i < numIterations; ++i) +; sum += ptr[i * stride]; +; +define float @testmul(float* %input, i32 %stride, i32 %numIterations) { +; CHECK-LABEL: @testmul +; CHECK: sext i32 %stride to i64 +; CHECK: loop: +; CHECK-DAG: phi float* +; CHECK-DAG: phi i32 +; CHECK-NOT: sext + +entry: + br label %loop + +loop: + %i = phi i32 [ %nexti, %loop ], [ 0, %entry ] + %sum = phi float [ %nextsum, %loop ], [ 0.000000e+00, %entry ] + %index32 = mul nuw nsw i32 %i, %stride + %index64 = sext i32 %index32 to i64 + %ptr = getelementptr inbounds float, float* %input, i64 %index64 + %addend = load float, float* %ptr, align 4 + %nextsum = fadd float %sum, %addend + %nexti = add nuw nsw i32 %i, 1 + %exitcond = icmp eq i32 %nexti, %numIterations + br i1 %exitcond, label %exit, label %loop + +exit: + ret float %nextsum +} + 
+; Test this pattern: +; +; for (int i = 0; i < numIterations; ++i) +; sum += ptr[3 * (i << 7)]; +; +; The multiplication by 3 is to make the address calculation expensive +; enough to force the introduction of a pointer induction variable. +define float @testshl(float* %input, i32 %numIterations) { +; CHECK-LABEL: @testshl +; CHECK: loop: +; CHECK-DAG: phi float* +; CHECK-DAG: phi i32 +; CHECK-NOT: sext + +entry: + br label %loop + +loop: + %i = phi i32 [ %nexti, %loop ], [ 0, %entry ] + %sum = phi float [ %nextsum, %loop ], [ 0.000000e+00, %entry ] + %index32 = shl nuw nsw i32 %i, 7 + %index32mul = mul nuw nsw i32 %index32, 3 + %index64 = sext i32 %index32mul to i64 + %ptr = getelementptr inbounds float, float* %input, i64 %index64 + %addend = load float, float* %ptr, align 4 + %nextsum = fadd float %sum, %addend + %nexti = add nuw nsw i32 %i, 1 + %exitcond = icmp eq i32 %nexti, %numIterations + br i1 %exitcond, label %exit, label %loop + +exit: + ret float %nextsum +}