Index: lib/Transforms/Scalar/LoopPredication.cpp =================================================================== --- lib/Transforms/Scalar/LoopPredication.cpp +++ lib/Transforms/Scalar/LoopPredication.cpp @@ -58,7 +58,7 @@ // I = PHI(Start, I.INC) // I.INC = I + Step // guard(G(I)); -// } while (B(I.INC)); +// } while (B(I)); // } // // where B(x) and G(x) are predicates that map integers to booleans, we want a @@ -70,10 +70,10 @@ // I = PHI(Start, I.INC) // I.INC = I + Step // guard(G(Start) && M); -// } while (B(I.INC)); +// } while (B(I)); // } // -// One solution for M is M = forall X . (G(X) && B(X + Step)) => G(X + Step) +// One solution for M is M = forall X . (G(X) && B(X)) => G(X + Step) // // Informal proof that the transformation above is correct: // @@ -90,64 +90,66 @@ // Induction step. Assuming G(Start) && M => G(I) on the subsequent // iteration: // -// B(I + Step) is true because it's the backedge condition. +// B(I) is true because it's the backedge condition. // G(I) is true because the backedge is guarded by this condition. // -// So M = forall X . (G(X) && B(X + Step)) => G(X + Step) implies -// G(I + Step). +// So M = forall X . (G(X) && B(X)) => G(X + Step) implies G(I + Step). // // Note that we can use anything stronger than M, i.e. any condition which // implies M. // // For now the transformation is limited to the following case: // * The loop has a single latch with the condition of the form: -// ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=. +// latchStart + i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=. // * The step of the IV used in the latch condition is 1. -// * The IV of the latch condition is the same as the post increment IV of the -// guard condition. -// * The guard condition is -// i u< guardLimit. +// * The guard condition is of the form +// guardStart + i u< guardLimit // // For the ult latch comparison case M is: -// forall X . X u< guardLimit && (X + 1) u< latchLimit => -// (X + 1) u< guardLimit +// forall X . 
guardStart + X u< guardLimit && latchStart + X u< latchLimit => +// guardStart + X + 1 u< guardLimit // -// This is true if latchLimit u<= guardLimit since then -// (X + 1) u< latchLimit u<= guardLimit == (X + 1) u< guardLimit. -// -// So for ult condition the widened condition is: -// i.start u< guardLimit && latchLimit u<= guardLimit -// Similarly for ule condition the widened condition is: -// i.start u< guardLimit && latchLimit u< guardLimit -// -// For the signed latch comparison case M is: -// forall X . X u< guardLimit && (X + 1) s< latchLimit => -// (X + 1) u< guardLimit -// -// The only way the antecedent can be true and the consequent can be false is +// The only way the antecedent can be true and the consequent can be false is // if -// X == guardLimit - 1 +// X == guardLimit - 1 - guardStart // (and guardLimit is non-zero, but we won't use this latter fact). -// If X == guardLimit - 1 then the second half of the antecedent is -// guardLimit s< latchLimit +// If X == guardLimit - 1 - guardStart then the second half of the antecedent is +// latchStart + guardLimit - 1 - guardStart u< latchLimit // and its negation is -// latchLimit s<= guardLimit. +// latchStart + guardLimit - 1 - guardStart u>= latchLimit // -// In other words, if latchLimit s<= guardLimit then: +// In other words, if +// latchLimit u<= latchStart + guardLimit - 1 - guardStart +// then: // (the ranges below are written in ConstantRange notation, where [A, B) is the // set for (I = A; I != B; I++ /*maywrap*/) yield(I);) // -// forall X . X u< guardLimit && (X + 1) s< latchLimit => (X + 1) u< guardLimit -// == forall X . X u< guardLimit && (X + 1) s< guardLimit => (X + 1) u< guardLimit -// == forall X . X in [0, guardLimit) && (X + 1) in [INT_MIN, guardLimit) => (X + 1) in [0, guardLimit) -// == forall X . X in [0, guardLimit) && X in [INT_MAX, guardLimit-1) => X in [-1, guardLimit-1) -// == forall X . X in [0, guardLimit-1) => X in [-1, guardLimit-1) +// forall X . 
guardStart + X u< guardLimit && +// latchStart + X u< latchLimit => +// guardStart + X + 1 u< guardLimit +// == forall X . guardStart + X u< guardLimit && +// latchStart + X u< latchStart + guardLimit - 1 - guardStart => +// guardStart + X + 1 u< guardLimit +// == forall X . (guardStart + X) in [0, guardLimit) && +// (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) => +// (guardStart + X + 1) in [0, guardLimit) +// == forall X . X in [-guardStart, guardLimit - guardStart) && +// X in [-latchStart, guardLimit - 1 - guardStart) => +// X in [-guardStart - 1, guardLimit - guardStart - 1) // == true // // So the widened condition is: -// i.start u< guardLimit && latchLimit s<= guardLimit -// Similarly for sle condition the widened condition is: -// i.start u< guardLimit && latchLimit s< guardLimit +// guardStart u< guardLimit && +// latchStart + guardLimit - 1 - guardStart u>= latchLimit +// Similarly for ule condition the widened condition is: +// guardStart u< guardLimit && +// latchStart + guardLimit - 1 - guardStart u> latchLimit +// For slt condition the widened condition is: +// guardStart u< guardLimit && +// latchStart + guardLimit - 1 - guardStart s>= latchLimit +// For sle condition the widened condition is: +// guardStart u< guardLimit && +// latchStart + guardLimit - 1 - guardStart s> latchLimit // //===----------------------------------------------------------------------===// @@ -322,20 +324,37 @@ return None; } auto *RangeCheckIV = RangeCheck->IV; - auto *PostIncRangeCheckIV = RangeCheckIV->getPostIncExpr(*SE); - if (LatchCheck.IV != PostIncRangeCheckIV) { - DEBUG(dbgs() << "Post increment range check IV (" << *PostIncRangeCheckIV - << ") is not the same as latch IV (" << *LatchCheck.IV - << ")!\n"); + auto *Ty = RangeCheckIV->getType(); + if (Ty != LatchCheck.IV->getType()) { + DEBUG(dbgs() << "Type mismatch between range check and latch IVs!\n"); + return None; + } + if (!RangeCheckIV->isAffine()) { + DEBUG(dbgs() << "Range check IV is 
not affine!\n"); + return None; + } + auto *Step = RangeCheckIV->getStepRecurrence(*SE); + if (Step != LatchCheck.IV->getStepRecurrence(*SE)) { + DEBUG(dbgs() << "Range check and latch have IVs different steps!\n"); return None; } - assert(RangeCheckIV->getStepRecurrence(*SE)->isOne() && "must be one"); - const SCEV *Start = RangeCheckIV->getStart(); + assert(Step->isOne() && "must be one"); // Generate the widened condition: - // i.start u< guardLimit && latchLimit guardLimit + // guardStart u< guardLimit && + // latchLimit guardLimit - 1 - guardStart + latchStart // where depends on the latch condition predicate. See the file // header comment for the reasoning. + const SCEV *GuardStart = RangeCheckIV->getStart(); + const SCEV *GuardLimit = RangeCheck->Limit; + const SCEV *LatchStart = LatchCheck.IV->getStart(); + const SCEV *LatchLimit = LatchCheck.Limit; + + // guardLimit - guardStart + latchStart - 1 + const SCEV *RHS = + SE->getAddExpr(SE->getMinusSCEV(GuardLimit, GuardStart), + SE->getMinusSCEV(LatchStart, SE->getOne(Ty))); + ICmpInst::Predicate LimitCheckPred; switch (LatchCheck.Pred) { case ICmpInst::ICMP_ULT: @@ -354,18 +373,24 @@ llvm_unreachable("Unsupported loop latch!"); } + DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n"); + DEBUG(dbgs() << "RHS: " << *RHS << "\n"); + DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n"); + auto CanExpand = [this](const SCEV *S) { return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE); }; - if (!CanExpand(Start) || !CanExpand(LatchCheck.Limit) || - !CanExpand(RangeCheck->Limit)) + if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) || + !CanExpand(LatchLimit) || !CanExpand(RHS)) { + DEBUG(dbgs() << "Can't expand limit check!\n"); return None; + } Instruction *InsertAt = Preheader->getTerminator(); - auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, - LatchCheck.Limit, RangeCheck->Limit, InsertAt); + auto *LimitCheck = + expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, RHS, InsertAt); auto 
*FirstIterationCheck = expandCheck(Expander, Builder, RangeCheck->Pred, - Start, RangeCheck->Limit, InsertAt); + GuardStart, GuardLimit, InsertAt); return Builder.CreateAnd(FirstIterationCheck, LimitCheck); } Index: test/Transforms/LoopPredication/basic.ll =================================================================== --- test/Transforms/LoopPredication/basic.ll +++ test/Transforms/LoopPredication/basic.ll @@ -255,6 +255,119 @@ ret i32 %result } +define i32 @signed_loop_0_to_n_preincrement_latch_check(i32* %array, i32 %length, i32 %n) { +; CHECK-LABEL: @signed_loop_0_to_n_preincrement_latch_check +entry: + %tmp5 = icmp sle i32 %n, 0 + br i1 %tmp5, label %exit, label %loop.preheader + +loop.preheader: +; CHECK: loop.preheader: +; CHECK: [[length_minus_1:[^ ]+]] = add i32 %length, -1 +; CHECK-NEXT: [[limit_check:[^ ]+]] = icmp sle i32 %n, [[length_minus_1]] +; CHECK-NEXT: [[first_iteration_check:[^ ]+]] = icmp ult i32 0, %length +; CHECK-NEXT: [[wide_cond:[^ ]+]] = and i1 [[first_iteration_check]], [[limit_check]] +; CHECK-NEXT: br label %loop + br label %loop + +loop: +; CHECK: loop: +; CHECK: call void (i1, ...) @llvm.experimental.guard(i1 [[wide_cond]], i32 9) [ "deopt"() ] + %loop.acc = phi i32 [ %loop.acc.next, %loop ], [ 0, %loop.preheader ] + %i = phi i32 [ %i.next, %loop ], [ 0, %loop.preheader ] + %within.bounds = icmp ult i32 %i, %length + call void (i1, ...) 
@llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ] + + %i.i64 = zext i32 %i to i64 + %array.i.ptr = getelementptr inbounds i32, i32* %array, i64 %i.i64 + %array.i = load i32, i32* %array.i.ptr, align 4 + %loop.acc.next = add i32 %loop.acc, %array.i + + %i.next = add i32 %i, 1 + %continue = icmp slt i32 %i, %n + br i1 %continue, label %loop, label %exit + +exit: + %result = phi i32 [ 0, %entry ], [ %loop.acc.next, %loop ] + ret i32 %result +} + +define i32 @signed_loop_0_to_n_sle_latch_offset_ult_check(i32* %array, i32 %length, i32 %n) { +; CHECK-LABEL: @signed_loop_0_to_n_sle_latch_offset_ult_check +entry: + %tmp5 = icmp sle i32 %n, 0 + br i1 %tmp5, label %exit, label %loop.preheader + +loop.preheader: +; CHECK: loop.preheader: +; CHECK: [[length_minus_1:[^ ]+]] = add i32 %length, -1 +; CHECK-NEXT: [[limit_check:[^ ]+]] = icmp slt i32 %n, [[length_minus_1]] +; CHECK-NEXT: [[first_iteration_check:[^ ]+]] = icmp ult i32 1, %length +; CHECK-NEXT: [[wide_cond:[^ ]+]] = and i1 [[first_iteration_check]], [[limit_check]] +; CHECK-NEXT: br label %loop + br label %loop + +loop: +; CHECK: loop: +; CHECK: call void (i1, ...) @llvm.experimental.guard(i1 [[wide_cond]], i32 9) [ "deopt"() ] + %loop.acc = phi i32 [ %loop.acc.next, %loop ], [ 0, %loop.preheader ] + %i = phi i32 [ %i.next, %loop ], [ 0, %loop.preheader ] + %i.offset = add i32 %i, 1 + %within.bounds = icmp ult i32 %i.offset, %length + call void (i1, ...) 
@llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ] + + %i.i64 = zext i32 %i to i64 + %array.i.ptr = getelementptr inbounds i32, i32* %array, i64 %i.i64 + %array.i = load i32, i32* %array.i.ptr, align 4 + %loop.acc.next = add i32 %loop.acc, %array.i + + %i.next = add i32 %i, 1 + %continue = icmp sle i32 %i.next, %n + br i1 %continue, label %loop, label %exit + +exit: + %result = phi i32 [ 0, %entry ], [ %loop.acc.next, %loop ] + ret i32 %result +} + +define i32 @signed_loop_0_to_n_offset_sle_latch_offset_ult_check(i32* %array, i32 %length, i32 %n) { +; CHECK-LABEL: @signed_loop_0_to_n_offset_sle_latch_offset_ult_check +entry: + %tmp5 = icmp sle i32 %n, 0 + br i1 %tmp5, label %exit, label %loop.preheader + +loop.preheader: +; CHECK: loop.preheader: +; CHECK: [[limit_check:[^ ]+]] = icmp slt i32 %n, %length +; CHECK-NEXT: [[first_iteration_check:[^ ]+]] = icmp ult i32 1, %length +; CHECK-NEXT: [[wide_cond:[^ ]+]] = and i1 [[first_iteration_check]], [[limit_check]] +; CHECK-NEXT: br label %loop + br label %loop + +loop: +; CHECK: loop: +; CHECK: call void (i1, ...) @llvm.experimental.guard(i1 [[wide_cond]], i32 9) [ "deopt"() ] + %loop.acc = phi i32 [ %loop.acc.next, %loop ], [ 0, %loop.preheader ] + %i = phi i32 [ %i.next, %loop ], [ 0, %loop.preheader ] + %i.offset = add i32 %i, 1 + %within.bounds = icmp ult i32 %i.offset, %length + call void (i1, ...) 
@llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ] + + %i.i64 = zext i32 %i to i64 + %array.i.ptr = getelementptr inbounds i32, i32* %array, i64 %i.i64 + %array.i = load i32, i32* %array.i.ptr, align 4 + %loop.acc.next = add i32 %loop.acc, %array.i + + %i.next = add i32 %i, 1 + %i.next.offset = add i32 %i.next, 1 + %continue = icmp sle i32 %i.next.offset, %n + br i1 %continue, label %loop, label %exit + +exit: + %result = phi i32 [ 0, %entry ], [ %loop.acc.next, %loop ] + ret i32 %result +} + define i32 @unsupported_latch_pred_loop_0_to_n(i32* %array, i32 %length, i32 %n) { ; CHECK-LABEL: @unsupported_latch_pred_loop_0_to_n entry: @@ -362,8 +475,88 @@ ret i32 %result } -define i32 @signed_loop_0_to_n_unrelated_iv_range_check(i32* %array, i32 %start, i32 %length, i32 %n) { -; CHECK-LABEL: @signed_loop_0_to_n_unrelated_iv_range_check +define i32 @signed_loop_start_to_n_offset_iv_range_check(i32* %array, i32 %start.i, + i32 %start.j, i32 %length, + i32 %n) { +; CHECK-LABEL: @signed_loop_start_to_n_offset_iv_range_check +entry: + %tmp5 = icmp sle i32 %n, 0 + br i1 %tmp5, label %exit, label %loop.preheader + +loop.preheader: +; CHECK: loop.preheader: +; CHECK: [[length_plus_start_i:[^ ]+]] = add i32 %length, %start.i +; CHECK-NEXT: [[limit:[^ ]+]] = sub i32 [[length_plus_start_i]], %start.j +; CHECK-NEXT: [[limit_check:[^ ]+]] = icmp sle i32 %n, [[limit]] +; CHECK-NEXT: [[first_iteration_check:[^ ]+]] = icmp ult i32 %start.j, %length +; CHECK-NEXT: [[wide_cond:[^ ]+]] = and i1 [[first_iteration_check]], [[limit_check]] +; CHECK-NEXT: br label %loop + br label %loop + +loop: +; CHECK: loop: +; CHECK: call void (i1, ...) @llvm.experimental.guard(i1 [[wide_cond]], i32 9) [ "deopt"() ] + %loop.acc = phi i32 [ %loop.acc.next, %loop ], [ 0, %loop.preheader ] + %i = phi i32 [ %i.next, %loop ], [ %start.i, %loop.preheader ] + %j = phi i32 [ %j.next, %loop ], [ %start.j, %loop.preheader ] + + %within.bounds = icmp ult i32 %j, %length + call void (i1, ...) 
@llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ] + + %i.i64 = zext i32 %i to i64 + %array.i.ptr = getelementptr inbounds i32, i32* %array, i64 %i.i64 + %array.i = load i32, i32* %array.i.ptr, align 4 + %loop.acc.next = add i32 %loop.acc, %array.i + + %j.next = add i32 %j, 1 + %i.next = add i32 %i, 1 + %continue = icmp slt i32 %i.next, %n + br i1 %continue, label %loop, label %exit + +exit: + %result = phi i32 [ 0, %entry ], [ %loop.acc.next, %loop ] + ret i32 %result +} + +define i32 @signed_loop_0_to_n_different_iv_types(i32* %array, i16 %length, i32 %n) { +; CHECK-LABEL: @signed_loop_0_to_n_different_iv_types +entry: + %tmp5 = icmp sle i32 %n, 0 + br i1 %tmp5, label %exit, label %loop.preheader + +loop.preheader: +; CHECK: loop.preheader: +; CHECK-NEXT: br label %loop + br label %loop + +loop: +; CHECK: loop: +; CHECK: %within.bounds = icmp ult i16 %j, %length +; CHECK-NEXT: call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ] + %loop.acc = phi i32 [ %loop.acc.next, %loop ], [ 0, %loop.preheader ] + %i = phi i32 [ %i.next, %loop ], [ 0, %loop.preheader ] + %j = phi i16 [ %j.next, %loop ], [ 0, %loop.preheader ] + + %within.bounds = icmp ult i16 %j, %length + call void (i1, ...) 
@llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ] + + %i.i64 = zext i32 %i to i64 + %array.i.ptr = getelementptr inbounds i32, i32* %array, i64 %i.i64 + %array.i = load i32, i32* %array.i.ptr, align 4 + %loop.acc.next = add i32 %loop.acc, %array.i + + %j.next = add i16 %j, 1 + %i.next = add i32 %i, 1 + %continue = icmp slt i32 %i.next, %n + br i1 %continue, label %loop, label %exit + +exit: + %result = phi i32 [ 0, %entry ], [ %loop.acc.next, %loop ] + ret i32 %result +} + +define i32 @signed_loop_0_to_n_different_iv_strides(i32* %array, i32 %length, i32 %n) { +; CHECK-LABEL: @signed_loop_0_to_n_different_iv_strides entry: %tmp5 = icmp sle i32 %n, 0 br i1 %tmp5, label %exit, label %loop.preheader @@ -379,7 +572,7 @@ ; CHECK-NEXT: call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ] %loop.acc = phi i32 [ %loop.acc.next, %loop ], [ 0, %loop.preheader ] %i = phi i32 [ %i.next, %loop ], [ 0, %loop.preheader ] - %j = phi i32 [ %j.next, %loop ], [ %start, %loop.preheader ] + %j = phi i32 [ %j.next, %loop ], [ 0, %loop.preheader ] %within.bounds = icmp ult i32 %j, %length call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ] @@ -389,7 +582,7 @@ %array.i = load i32, i32* %array.i.ptr, align 4 %loop.acc.next = add i32 %loop.acc, %array.i - %j.next = add nsw i32 %j, 1 + %j.next = add nsw i32 %j, 2 %i.next = add nsw i32 %i, 1 %continue = icmp slt i32 %i.next, %n br i1 %continue, label %loop, label %exit