diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp --- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -167,8 +167,7 @@ // count computed by SCEV then this is either because the trip count variable // has been widened (then leave the trip count as it is), or because it is a // constant and another transformation has changed the compare, e.g. - // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, then we don't flatten - // the loop (yet). + // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1. TripCount = Compare->getOperand(1); const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); if (isa(BackedgeTakenCount)) { @@ -176,12 +175,22 @@ return false; } const SCEV *SCEVTripCount = SE->getTripCountFromExitCount(BackedgeTakenCount); - if (SE->getSCEV(TripCount) != SCEVTripCount) { - if (!IsWidened) { + if (SE->getSCEV(TripCount) != SCEVTripCount && !IsWidened) { + ConstantInt *RHS = dyn_cast(TripCount); + if (!RHS) { LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); return false; } - auto TripCountInst = dyn_cast(TripCount); + // The L->isCanonical check above ensures we only get here if the loop + // increments by 1 on each iteration, so the RHS of the Compare is + // tripcount-1 (i.e. equivalent to the backedge taken count). 
+ assert(SE->getSCEV(RHS) == BackedgeTakenCount && + "Expected RHS of compare to be equal to the backedge taken count"); + ConstantInt *One = ConstantInt::get(RHS->getType(), 1); + TripCount = ConstantInt::get(TripCount->getContext(), + RHS->getValue() + One->getValue()); + } else if (SE->getSCEV(TripCount) != SCEVTripCount) { + auto *TripCountInst = dyn_cast(TripCount); if (!TripCountInst) { LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n"); return false; @@ -368,6 +377,13 @@ U = *U->user_begin(); } + // If the use is in the compare (which is also the condition of the inner + // branch) then the compare has been altered by another transformation e.g. + // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is + // a constant. Ignore this use as the compare gets removed later anyway. + if (U == FI.InnerBranch->getCondition()) + continue; + LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump()); Value *MatchedMul; diff --git a/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll b/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll --- a/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll +++ b/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll @@ -341,38 +341,111 @@ ret i32 10 } -; When the loop trip count is a constant (e.g. 20) and the step size is -; 1, InstCombine causes the transformation icmp ult i32 %inc, 20 -> -; icmp ult i32 %j, 19. In this case a valid trip count is not found so -; the loop is not flattened. -define i32 @test9(i32* nocapture %A) { +; test_10, test_11 and test_12 are for the case when the +; inner trip count is a constant, then the InstCombine pass +; makes the transformation icmp ult i32 %inc, tripcount -> +; icmp ult i32 %j, tripcount-step. + +; test_10: The step is not 1. 
+define i32 @test_10(i32* nocapture %A) { entry: br label %for.cond1.preheader for.cond1.preheader: - %i.017 = phi i32 [ 0, %entry ], [ %inc6, %for.cond.cleanup3 ] + %i.017 = phi i32 [ 0, %entry ], [ %inc, %for.cond.cleanup3 ] %mul = mul i32 %i.017, 20 br label %for.body4 +for.body4: + %j.016 = phi i32 [ 0, %for.cond1.preheader ], [ %add5, %for.body4 ] + %add = add i32 %j.016, %mul + %arrayidx = getelementptr inbounds i32, i32* %A, i32 %add + store i32 30, i32* %arrayidx, align 4 + %add5 = add nuw nsw i32 %j.016, 2 + %cmp2 = icmp ult i32 %j.016, 18 + br i1 %cmp2, label %for.body4, label %for.cond.cleanup3 + for.cond.cleanup3: - %inc6 = add i32 %i.017, 1 - %cmp = icmp ult i32 %inc6, 11 + %inc = add i32 %i.017, 1 + %cmp = icmp ult i32 %inc, 11 br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup +for.cond.cleanup: + %0 = load i32, i32* %A, align 4 + ret i32 %0 +} + +; test_11: The inner induction variable is used in a compare which +; isn't the condition of the inner branch. +define i32 @test_11(i32* nocapture %A) { +entry: + br label %for.cond1.preheader + +for.cond1.preheader: + %i.020 = phi i32 [ 0, %entry ], [ %inc7, %for.cond.cleanup3 ] + %mul = mul i32 %i.020, 20 + br label %for.body4 + for.body4: - %j.016 = phi i32 [ 0, %for.cond1.preheader ], [ %inc, %for.body4 ] - %add = add i32 %j.016, %mul + %j.019 = phi i32 [ 0, %for.cond1.preheader ], [ %inc, %for.body4 ] + %cmp5 = icmp ult i32 %j.019, 5 + %cond = select i1 %cmp5, i32 30, i32 15 + %add = add i32 %j.019, %mul %arrayidx = getelementptr inbounds i32, i32* %A, i32 %add - store i32 30, i32* %arrayidx, align 4 - %inc = add nuw nsw i32 %j.016, 1 - %cmp2 = icmp ult i32 %j.016, 19 + store i32 %cond, i32* %arrayidx, align 4 + %inc = add nuw nsw i32 %j.019, 1 + %cmp2 = icmp ult i32 %j.019, 19 br i1 %cmp2, label %for.body4, label %for.cond.cleanup3 +for.cond.cleanup3: + %inc7 = add i32 %i.020, 1 + %cmp = icmp ult i32 %inc7, 11 + br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup + 
for.cond.cleanup: %0 = load i32, i32* %A, align 4 ret i32 %0 } +; test_12: Incoming phi node value for preheader is a variable +define i32 @test_12(i32* %A) { +entry: + br label %while.cond1.preheader + +while.cond1.preheader: + %j.017 = phi i32 [ 0, %entry ], [ %j.1, %while.end ] + %i.016 = phi i32 [ 0, %entry ], [ %inc4, %while.end ] + %mul = mul i32 %i.016, 20 + %cmp214 = icmp ult i32 %j.017, 20 + br i1 %cmp214, label %while.body3.preheader, label %while.end + +while.body3.preheader: + br label %while.body3 + +while.body3: + %j.115 = phi i32 [ %inc, %while.body3 ], [ %j.017, %while.body3.preheader ] + %add = add i32 %j.115, %mul + %arrayidx = getelementptr inbounds i32, i32* %A, i32 %add + store i32 30, i32* %arrayidx, align 4 + %inc = add nuw nsw i32 %j.115, 1 + %cmp2 = icmp ult i32 %j.115, 19 + br i1 %cmp2, label %while.body3, label %while.end.loopexit + +while.end.loopexit: + %inc.lcssa = phi i32 [ %inc, %while.body3 ] + br label %while.end + +while.end: + %j.1 = phi i32 [ %j.017, %while.cond1.preheader], [ %inc.lcssa, %while.end.loopexit ] + %inc4 = add i32 %i.016, 1 + %cmp = icmp ult i32 %inc4, 11 + br i1 %cmp, label %while.cond1.preheader, label %while.end5 + +while.end5: + %0 = load i32, i32* %A, align 4 + ret i32 %0 +} + ; Outer loop conditional phi define i32 @e() { entry: @@ -683,5 +756,36 @@ br i1 %cmp4, label %for.body7, label %for.cond.cleanup6.loopexit } +; Invalid trip count +define void @invalid_tripCount(i8* %a, i32 %b, i32 %c, i32 %initial-mutations, i32 %statemutations) { +entry: + %iszero = icmp eq i32 %b, 0 + br i1 %iszero, label %for.empty, label %for.loopinit +for.loopinit: + br label %for.loopbody.outer +for.loopbody.outer: + %for.count.ph = phi i32 [ %c, %for.refetch ], [ %b, %for.loopinit ] + br label %for.loopbody +for.loopbody: + %for.index = phi i32 [ %1, %for.notmutated ], [ 0, %for.loopbody.outer ] + %0 = icmp eq i32 %statemutations, %initial-mutations + br i1 %0, label %for.notmutated, label %for.mutated +for.mutated: + call void 
@objc_enumerationMutation(i8* %a) + br label %for.notmutated +for.notmutated: + %1 = add nuw i32 %for.index, 1 + %2 = icmp ult i32 %1, %for.count.ph + br i1 %2, label %for.loopbody, label %for.refetch +for.refetch: + %3 = icmp eq i32 %c, 0 + br i1 %3, label %for.empty.loopexit, label %for.loopbody.outer +for.empty.loopexit: + br label %for.empty +for.empty: + ret void +} + +declare void @objc_enumerationMutation(i8*) declare dso_local void @f(i32*) declare dso_local void @g(...) diff --git a/llvm/test/Transforms/LoopFlatten/loop-flatten.ll b/llvm/test/Transforms/LoopFlatten/loop-flatten.ll --- a/llvm/test/Transforms/LoopFlatten/loop-flatten.ll +++ b/llvm/test/Transforms/LoopFlatten/loop-flatten.ll @@ -586,6 +586,59 @@ ret i32 10 } +; When the inner loop trip count is a constant and the step +; is 1, the InstCombine pass causes the transformation e.g. +; icmp ult i32 %inc, 20 -> icmp ult i32 %j, 19. This doesn't +; match the pattern (OuterPHI * InnerTripCount) + InnerPHI but +; we should still flatten the loop as the compare is removed +; later anyway. 
+define i32 @test9(i32* nocapture %A) { +entry: + br label %for.cond1.preheader +; CHECK-LABEL: test9 +; CHECK: entry: +; CHECK: %flatten.tripcount = mul i32 20, 11 +; CHECK: br label %for.cond1.preheader + +for.cond1.preheader: + %i.017 = phi i32 [ 0, %entry ], [ %inc6, %for.cond.cleanup3 ] + %mul = mul i32 %i.017, 20 + br label %for.body4 +; CHECK: for.cond1.preheader: +; CHECK: %i.017 = phi i32 [ 0, %entry ], [ %inc6, %for.cond.cleanup3 ] +; CHECK: %mul = mul i32 %i.017, 20 +; CHECK: br label %for.body4 + +for.cond.cleanup3: + %inc6 = add i32 %i.017, 1 + %cmp = icmp ult i32 %inc6, 11 + br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup +; CHECK: for.cond.cleanup3: +; CHECK: %inc6 = add i32 %i.017, 1 +; CHECK: %cmp = icmp ult i32 %inc6, %flatten.tripcount +; CHECK: br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup + +for.body4: + %j.016 = phi i32 [ 0, %for.cond1.preheader ], [ %inc, %for.body4 ] + %add = add i32 %j.016, %mul + %arrayidx = getelementptr inbounds i32, i32* %A, i32 %add + store i32 30, i32* %arrayidx, align 4 + %inc = add nuw nsw i32 %j.016, 1 + %cmp2 = icmp ult i32 %j.016, 19 + br i1 %cmp2, label %for.body4, label %for.cond.cleanup3 +; CHECK: for.body4 +; CHECK: %j.016 = phi i32 [ 0, %for.cond1.preheader ] +; CHECK: %add = add i32 %j.016, %mul +; CHECK: %arrayidx = getelementptr inbounds i32, i32* %A, i32 %i.017 +; CHECK: store i32 30, i32* %arrayidx, align 4 +; CHECK: %inc = add nuw nsw i32 %j.016, 1 +; CHECK: %cmp2 = icmp ult i32 %j.016, 19 +; CHECK: br label %for.cond.cleanup3 + +for.cond.cleanup: + %0 = load i32, i32* %A, align 4 + ret i32 %0 +} declare i32 @func(i32)