Index: llvm/lib/Transforms/Scalar/LoopFlatten.cpp =================================================================== --- llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -94,6 +94,11 @@ // Whether this holds the flatten info before or after widening. bool Widened = false; + // Holds the old/narrow induction phis, i.e. the Phis before IV widening has + // been applied. This bookkeeping is used so we can skip some checks on these + // phi nodes. + SmallPtrSet OldInductionPHIs; + FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {}; }; @@ -263,6 +268,8 @@ // them specially when doing the transformation. if (&InnerPHI == FI.InnerInductionPHI) continue; + if (FI.Widened && FI.OldInductionPHIs.count(&InnerPHI)) + continue; // Each inner loop PHI node must have two incoming values/blocks - one // from the pre-header, and one from the latch. @@ -308,6 +315,8 @@ } for (PHINode &OuterPHI : FI.OuterLoop->getHeader()->phis()) { + if (FI.Widened && FI.OldInductionPHIs.count(&OuterPHI)) + continue; if (!SafeOuterPHIs.count(&OuterPHI)) { LLVM_DEBUG(dbgs() << "found unsafe PHI in outer loop: "; OuterPHI.dump()); return false; @@ -398,8 +407,8 @@ if (U == FI.InnerIncrement) continue; - // After widening the IVs, a trunc instruction might have been introduced, so - // look through truncs. + // After widening the IVs, a trunc instruction might have been introduced, + // so look through truncs. if (isa(U)) { if (!U->hasOneUse()) return false; @@ -424,11 +433,23 @@ // Matches the same pattern as above, except it also looks for truncs // on the phi, which can be the result of widening the induction variables. - bool IsAddTrunc = match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)), - m_Value(MatchedMul))) && - match(MatchedMul, - m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)), - m_Value(MatchedItCount))); + bool IsAddTrunc = + match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)), + m_Value(MatchedMul))) && + match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)), + m_Value(MatchedItCount))); + + if (!MatchedItCount) + return false; + // Look through extends if the IV has been widened. + if (FI.Widened && + (isa(MatchedItCount) || isa(MatchedItCount))) { + assert(MatchedItCount->getType() == FI.InnerInductionPHI->getType() && + "Unexpected type mismatch in types after widening"); + MatchedItCount = isa(MatchedItCount) + ? dyn_cast(MatchedItCount)->getOperand(0) + : dyn_cast(MatchedItCount)->getOperand(0); + } if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) { LLVM_DEBUG(dbgs() << "Use is optimisable\n"); @@ -668,14 +689,11 @@ } SCEVExpander Rewriter(*SE, DL, "loopflatten"); - SmallVector WideIVs; SmallVector DeadInsts; - WideIVs.push_back( {FI.InnerInductionPHI, MaxLegalType, false }); - WideIVs.push_back( {FI.OuterInductionPHI, MaxLegalType, false }); unsigned ElimExt = 0; unsigned Widened = 0; - for (const auto &WideIV : WideIVs) { + auto CreateWideIV = [&] (WideIVInfo WideIV, bool &Deleted) -> bool { PHINode *WidePhi = createWideIV(WideIV, LI, SE, Rewriter, DT, DeadInsts, ElimExt, Widened, true /* HasGuards */, true /* UsePostIncrementRanges */); @@ -683,11 +701,28 @@ return false; LLVM_DEBUG(dbgs() << "Created wide phi: "; WidePhi->dump()); LLVM_DEBUG(dbgs() << "Deleting old phi: "; WideIV.NarrowIV->dump()); - RecursivelyDeleteDeadPHINode(WideIV.NarrowIV); - } - // After widening, rediscover all the loop components. + Deleted = RecursivelyDeleteDeadPHINode(WideIV.NarrowIV); + return true; + }; + + bool Deleted; + if (!CreateWideIV({FI.InnerInductionPHI, MaxLegalType, false }, Deleted)) + return false; + // If the inner Phi node cannot be trivially deleted, we need to at least + // bring it in a consistent state. + if (!Deleted) + FI.InnerInductionPHI->removeIncomingValue(FI.InnerLoop->getLoopLatch()); + if (!CreateWideIV({FI.OuterInductionPHI, MaxLegalType, false }, Deleted)) + return false; + assert(Widened && "Widened IV expected"); FI.Widened = true; + + // Save the old/narrow induction phis, which we need to ignore in CheckPHIs. + FI.OldInductionPHIs.insert(FI.InnerInductionPHI); + FI.OldInductionPHIs.insert(FI.OuterInductionPHI); + + // After widening, rediscover all the loop components. return CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI); } Index: llvm/test/Transforms/LoopFlatten/widen-iv.ll =================================================================== --- llvm/test/Transforms/LoopFlatten/widen-iv.ll +++ llvm/test/Transforms/LoopFlatten/widen-iv.ll @@ -111,6 +111,7 @@ ; CHECK-NEXT: [[TMP0:%.*]] = sext i32 [[M]] to i64 ; CHECK-NEXT: [[TMP1:%.*]] = sext i32 [[M]] to i64 ; CHECK-NEXT: [[TMP2:%.*]] = sext i32 [[N]] to i64 +; CHECK-NEXT: [[FLATTEN_TRIPCOUNT:%.*]] = mul i64 [[TMP0]], [[TMP2]] ; CHECK-NEXT: br label [[FOR_COND1_PREHEADER_US:%.*]] ; CHECK: for.cond1.preheader.us: ; CHECK-NEXT: [[INDVAR2:%.*]] = phi i64 [ [[INDVAR_NEXT3:%.*]], [[FOR_COND1_FOR_COND_CLEANUP3_CRIT_EDGE_US:%.*]] ], [ 0, [[FOR_COND1_PREHEADER_US_PREHEADER]] ] @@ -118,26 +119,27 @@ ; CHECK-NEXT: [[TMP3:%.*]] = mul nsw i64 [[INDVAR2]], [[TMP1]] ; CHECK-NEXT: [[MUL_US:%.*]] = mul nsw i32 [[I_018_US]], [[M]] ; CHECK-NEXT: [[TMP4:%.*]] = sext i32 [[MUL_US]] to i64 +; CHECK-NEXT: [[FLATTEN_TRUNCIV:%.*]] = trunc i64 [[INDVAR2]] to i32 ; CHECK-NEXT: br label [[FOR_BODY4_US:%.*]] ; CHECK: for.body4.us: -; CHECK-NEXT: [[INDVAR:%.*]] = phi i64 [ [[INDVAR_NEXT:%.*]], [[FOR_BODY4_US]] ], [ 0, [[FOR_COND1_PREHEADER_US]] ] -; CHECK-NEXT: [[J_016_US:%.*]] = phi i32 [ 0, [[FOR_COND1_PREHEADER_US]] ], [ [[INC_US:%.*]], [[FOR_BODY4_US]] ] +; CHECK-NEXT: [[INDVAR:%.*]] = phi i64 [ 0, [[FOR_COND1_PREHEADER_US]] ] +; CHECK-NEXT: [[J_016_US:%.*]] = phi i32 [ 0, [[FOR_COND1_PREHEADER_US]] ] ; CHECK-NEXT: [[TMP5:%.*]] = add nsw i64 [[INDVAR]], [[TMP3]] ; CHECK-NEXT: [[TMP6:%.*]] = sext i32 [[J_016_US]] to i64 ; CHECK-NEXT: [[TMP7:%.*]] = add nsw i64 [[TMP6]], [[TMP3]] ; CHECK-NEXT: [[ADD_US:%.*]] = add nsw i32 [[J_016_US]], [[MUL_US]] -; CHECK-NEXT: [[IDXPROM_US:%.*]] = sext i32 [[ADD_US]] to i64 -; CHECK-NEXT: [[ARRAYIDX_US:%.*]] = getelementptr inbounds i32, i32* [[A:%.*]], i64 [[TMP5]] +; CHECK-NEXT: [[IDXPROM_US:%.*]] = sext i32 [[FLATTEN_TRUNCIV]] to i64 +; CHECK-NEXT: [[ARRAYIDX_US:%.*]] = getelementptr inbounds i32, i32* [[A:%.*]], i64 [[INDVAR2]] ; CHECK-NEXT: [[TMP8:%.*]] = load i32, i32* [[ARRAYIDX_US]], align 4 ; CHECK-NEXT: tail call void @g(i32 [[TMP8]]) -; CHECK-NEXT: [[INDVAR_NEXT]] = add i64 [[INDVAR]], 1 -; CHECK-NEXT: [[INC_US]] = add nuw nsw i32 [[J_016_US]], 1 +; CHECK-NEXT: [[INDVAR_NEXT:%.*]] = add i64 [[INDVAR]], 1 +; CHECK-NEXT: [[INC_US:%.*]] = add nuw nsw i32 [[J_016_US]], 1 ; CHECK-NEXT: [[CMP2_US:%.*]] = icmp slt i64 [[INDVAR_NEXT]], [[TMP0]] -; CHECK-NEXT: br i1 [[CMP2_US]], label [[FOR_BODY4_US]], label [[FOR_COND1_FOR_COND_CLEANUP3_CRIT_EDGE_US]] +; CHECK-NEXT: br label [[FOR_COND1_FOR_COND_CLEANUP3_CRIT_EDGE_US]] ; CHECK: for.cond1.for.cond.cleanup3_crit_edge.us: ; CHECK-NEXT: [[INDVAR_NEXT3]] = add i64 [[INDVAR2]], 1 ; CHECK-NEXT: [[INC6_US]] = add nuw nsw i32 [[I_018_US]], 1 -; CHECK-NEXT: [[CMP_US:%.*]] = icmp slt i64 [[INDVAR_NEXT3]], [[TMP2]] +; CHECK-NEXT: [[CMP_US:%.*]] = icmp slt i64 [[INDVAR_NEXT3]], [[FLATTEN_TRIPCOUNT]] ; CHECK-NEXT: br i1 [[CMP_US]], label [[FOR_COND1_PREHEADER_US]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]] ; CHECK: for.cond1.preheader: ; CHECK-NEXT: [[I_018:%.*]] = phi i32 [ [[INC6:%.*]], [[FOR_COND1_PREHEADER]] ], [ 0, [[FOR_COND1_PREHEADER_PREHEADER]] ]