diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp --- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -1243,7 +1243,11 @@ InductionPHI->moveBefore(&InductionPHI->getParent()->front()); // Split at the place were the induction variable is - // incremented/decremented. + // incremented/decremented. This might result in instructions in the split + // block using instructions from the original block. In this case, we + // have to insert LCSSA phi nodes later on, because the split off block + // will be outside the inner loop. We do this after moving the existing + // LCSSA phis in adjustLoopBranches. // TODO: This splitting logic may not work always. Fix this. splitInnerLoopLatch(InnerIndexVar); LLVM_DEBUG(dbgs() << "splitInnerLoopLatch done\n"); @@ -1298,83 +1302,94 @@ } } -// Move Lcssa PHIs to the right place. -static void moveLCSSAPhis(BasicBlock *InnerExit, BasicBlock *InnerHeader, - BasicBlock *InnerLatch, BasicBlock *OuterHeader, - BasicBlock *OuterLatch, BasicBlock *OuterExit) { - - // Deal with LCSSA PHI nodes in the exit block of the inner loop, that are - // defined either in the header or latch. Those blocks will become header and - // latch of the new outer loop, and the only possible users can PHI nodes +// Move LCSSA PHIs to the right place and add PHIs for the block split off the +// original inner loop latch. +static void fixLCSSAPhis(Loop *InnerLoop, Loop *OuterLoop, + BasicBlock *OrigInnerExit) { + BasicBlock *OrigInnerHeader = OuterLoop->getHeader(); + BasicBlock *OrigInnerLatch = OuterLoop->getLoopLatch(); + BasicBlock *OrigOuterHeader = InnerLoop->getHeader(); + BasicBlock *OrigOuterLatch = InnerLoop->getLoopLatch(); + // The new inner loop can have multiple exits. Here we are only interested + // in the exit block in the loop nest, which is the outer loop's latch after + // interchanging. 
+ BasicBlock *InnerExit = OuterLoop->getLoopLatch(); + BasicBlock *InnerLatch = InnerLoop->getLoopLatch(); + BasicBlock *OuterLatch = OuterLoop->getLoopLatch(); + BasicBlock *OuterExit = OuterLoop->getExitBlock(); + + // Deal with LCSSA PHI nodes in the exit block of the original inner loop, + // that are defined either in the header or latch. Those blocks became header + // and latch of the new outer loop, and the only possible users can be PHI nodes // in the exit block of the loop nest or the outer loop header (reduction - // PHIs, in that case, the incoming value must be defined in the inner loop - // header). We can just substitute the user with the incoming value and remove - // the PHI. - for (PHINode &P : make_early_inc_range(InnerExit->phis())) { + // PHIs, in that case, the incoming value must be defined in the original + // inner loop header). We can just substitute the user with the incoming value + // and remove the PHI. + for (PHINode &P : make_early_inc_range(OrigInnerExit->phis())) { assert(P.getNumIncomingValues() == 1 && "Only loops with a single exit are supported!"); // Incoming values are guaranteed be instructions currently. - auto IncI = cast(P.getIncomingValueForBlock(InnerLatch)); + auto IncI = cast(P.getIncomingValueForBlock(OrigInnerLatch)); // Skip phis with incoming values from the inner loop body, excluding the // header and latch. 
- if (IncI->getParent() != InnerLatch && IncI->getParent() != InnerHeader) + if (IncI->getParent() != OrigInnerLatch && + IncI->getParent() != OrigInnerHeader) continue; - assert(all_of(P.users(), - [OuterHeader, OuterExit, IncI, InnerHeader](User *U) { - return (cast(U)->getParent() == OuterHeader && - IncI->getParent() == InnerHeader) || - cast(U)->getParent() == OuterExit; - }) && - "Can only replace phis iff the uses are in the loop nest exit or " - "the incoming value is defined in the inner header (it will " - "dominate all loop blocks after interchanging)"); + assert( + all_of(P.users(), + [OrigOuterHeader, OuterExit, IncI, OrigInnerHeader](User *U) { + return (cast(U)->getParent() == OrigOuterHeader && + IncI->getParent() == OrigInnerHeader) || + cast(U)->getParent() == OuterExit; + }) && + "Can only replace phis iff the uses are in the loop nest exit or " + "the incoming value is defined in the original inner header (it will " + "dominate all loop blocks after interchanging)"); P.replaceAllUsesWith(IncI); P.eraseFromParent(); } SmallVector LcssaInnerExit; - for (PHINode &P : InnerExit->phis()) + for (PHINode &P : OrigInnerExit->phis()) LcssaInnerExit.push_back(&P); SmallVector LcssaInnerLatch; - for (PHINode &P : InnerLatch->phis()) + for (PHINode &P : OrigInnerLatch->phis()) LcssaInnerLatch.push_back(&P); - // Lcssa PHIs for values used outside the inner loop are in InnerExit. - // If a PHI node has users outside of InnerExit, it has a use outside the - // interchanged loop and we have to preserve it. We move these to - // InnerLatch, which will become the new exit block for the innermost - // loop after interchanging. + // Lcssa PHIs for values used outside the original inner loop are in + // OrigInnerExit. If a PHI node has users outside of OrigInnerExit, it has a + // use outside the interchanged loop and we have to preserve it. We move these + // to the new exit block of the inner loop. 
for (PHINode *P : LcssaInnerExit) - P->moveBefore(InnerLatch->getFirstNonPHI()); + P->moveBefore(InnerLoop->getExitBlock()->getFirstNonPHI()); - // If the inner loop latch contains LCSSA PHIs, those come from a child loop - // and we have to move them to the new inner latch. + // If the original inner loop latch contains LCSSA PHIs, those come from a + // child loop and we have to move them to the new successor in the inner loop. for (PHINode *P : LcssaInnerLatch) - P->moveBefore(InnerExit->getFirstNonPHI()); + P->moveBefore(OrigInnerExit->getFirstNonPHI()); // Deal with LCSSA PHI nodes in the loop nest exit block. For PHIs that have - // incoming values from the outer latch or header, we have to add a new PHI - // in the inner loop latch, which became the exit block of the outer loop, - // after interchanging. + // incoming values from the original outer latch or header, we have to add a + // new PHI to the new inner exit block. if (OuterExit) { for (PHINode &P : OuterExit->phis()) { if (P.getNumIncomingValues() != 1) continue; - // Skip Phis with incoming values not defined in the outer loop's header - // and latch. Also skip incoming phis defined in the latch. Those should - // already have been updated. + // Skip Phis with incoming values not defined in the original outer loop's + // header and latch. Also skip incoming phis defined in the latch. Those + // should already have been updated. 
auto I = dyn_cast(P.getIncomingValue(0)); - if (!I || ((I->getParent() != OuterLatch || isa(I)) && - I->getParent() != OuterHeader)) + if (!I || ((I->getParent() != OrigOuterLatch || isa(I)) && + I->getParent() != OrigOuterHeader)) continue; PHINode *NewPhi = dyn_cast(P.clone()); NewPhi->setIncomingValue(0, P.getIncomingValue(0)); - NewPhi->setIncomingBlock(0, OuterLatch); - NewPhi->insertBefore(InnerLatch->getFirstNonPHI()); + NewPhi->setIncomingBlock(0, OrigOuterLatch); + NewPhi->insertBefore(InnerLoop->getExitBlock()->getFirstNonPHI()); P.setIncomingValue(0, NewPhi); } } @@ -1382,7 +1397,37 @@ // Now adjust the incoming blocks for the LCSSA PHIs. // For PHIs moved from Inner's exit block, we need to replace Inner's latch // with the new latch. - InnerLatch->replacePhiUsesWith(InnerLatch, OuterLatch); + InnerExit->replacePhiUsesWith(OrigInnerLatch, InnerLatch); + + // We split the original inner loop latch earlier and it became the latch of + // the outer loop. There can be instructions in the + // split off latch that use instructions from the inner loop and we have to + // create LCSSA phis for them. For instructions with operands in the inner + // loop, we need to create new LCSSA PHIs and update the users outside of the + // inner loop. 
+ SmallSetVector NeedsLcssaPHI; + for (Instruction &I : *OuterLatch) { + if (isa(&I)) + continue; + for (Use &Op : make_early_inc_range(I.operands())) { + Instruction *OpI = dyn_cast(Op.get()); + if (!OpI || !InnerLoop->contains(OpI)) + continue; + NeedsLcssaPHI.insert(OpI); + } + } + for (Instruction *NI : NeedsLcssaPHI) { + PHINode *LcssaPhi = PHINode::Create(NI->getType(), 1, "lcssa.", + InnerExit->getFirstNonPHI()); + LcssaPhi->addIncoming(NI, InnerLatch); + for (auto UI = NI->use_begin(), UE = NI->use_end(); UI != UE;) { + Use &U = *UI++; + if (U.getUser() == LcssaPhi || + InnerLoop->contains(cast(U.getUser()))) + continue; + U.set(LcssaPhi); + } + } } bool LoopInterchangeTransform::adjustLoopBranches() { @@ -1480,9 +1525,9 @@ DT->applyUpdates(DTUpdates); restructureLoops(OuterLoop, InnerLoop, InnerLoopPreHeader, OuterLoopPreHeader); - - moveLCSSAPhis(InnerLoopLatchSuccessor, InnerLoopHeader, InnerLoopLatch, - OuterLoopHeader, OuterLoopLatch, InnerLoop->getExitBlock()); + // Update loop variables after updating LoopInfo. + std::swap(OuterLoop, InnerLoop); + fixLCSSAPhis(InnerLoop, OuterLoop, InnerLoopLatchSuccessor); // For PHIs in the exit block of the outer loop, outer's latch has been // replaced by Inners'. 
OuterLoopLatchSuccessor->replacePhiUsesWith(OuterLoopLatch, InnerLoopLatch); diff --git a/llvm/test/Transforms/LoopInterchange/pr43176-preserve-lcssa.ll b/llvm/test/Transforms/LoopInterchange/pr43176-preserve-lcssa.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LoopInterchange/pr43176-preserve-lcssa.ll @@ -0,0 +1,70 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -loop-interchange -verify-loop-lcssa -verify-dom-info -S %s | FileCheck %s + +@b = external dso_local global [5 x i32], align 16 + +define void @d() { +; CHECK-LABEL: @d( +; CHECK-NEXT: entry: +; CHECK-NEXT: br label [[FOR_BODY2_PREHEADER:%.*]] +; CHECK: for.body.preheader: +; CHECK-NEXT: br label [[FOR_BODY:%.*]] +; CHECK: for.body: +; CHECK-NEXT: [[INC41:%.*]] = phi i32 [ [[INC4:%.*]], [[FOR_INC3:%.*]] ], [ undef, [[FOR_BODY_PREHEADER:%.*]] ] +; CHECK-NEXT: [[IDXPROM:%.*]] = sext i32 [[INC41]] to i64 +; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds [5 x i32], [5 x i32]* @b, i64 0, i64 [[IDXPROM]] +; CHECK-NEXT: br label [[FOR_BODY2_SPLIT:%.*]] +; CHECK: for.body2.preheader: +; CHECK-NEXT: br label [[FOR_BODY2:%.*]] +; CHECK: for.body2: +; CHECK-NEXT: [[LSR_IV:%.*]] = phi i32 [ [[LSR_IV_NEXT:%.*]], [[FOR_INC_SPLIT:%.*]] ], [ 1, [[FOR_BODY2_PREHEADER]] ] +; CHECK-NEXT: br label [[FOR_BODY_PREHEADER]] +; CHECK: for.body2.split: +; CHECK-NEXT: br label [[FOR_INC:%.*]] +; CHECK: for.inc: +; CHECK-NEXT: [[TMP0:%.*]] = load i32, i32* [[ARRAYIDX]], align 4 +; CHECK-NEXT: store i32 undef, i32* [[ARRAYIDX]], align 4 +; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[LSR_IV]], 4 +; CHECK-NEXT: br label [[FOR_COND1_FOR_END_CRIT_EDGE:%.*]] +; CHECK: for.inc.split: +; CHECK-NEXT: [[LCSSA_:%.*]] = phi i1 [ [[CMP]], [[FOR_INC3]] ] +; CHECK-NEXT: [[LSR_IV_NEXT]] = add nuw nsw i32 [[LSR_IV]], 1 +; CHECK-NEXT: br i1 [[LCSSA_]], label [[FOR_BODY2]], label [[FOR_COND_FOR_END5_CRIT_EDGE:%.*]] +; CHECK: for.cond1.for.end_crit_edge: +; CHECK-NEXT: br label 
[[FOR_INC3]] +; CHECK: for.inc3: +; CHECK-NEXT: [[INC4]] = add nsw i32 [[INC41]], 1 +; CHECK-NEXT: br i1 false, label [[FOR_BODY]], label [[FOR_INC_SPLIT]] +; CHECK: for.cond.for.end5_crit_edge: +; CHECK-NEXT: ret void +; +entry: + br label %for.body + +for.body: ; preds = %for.inc3, %entry + %inc41 = phi i32 [ %inc4, %for.inc3 ], [ undef, %entry ] + br label %for.body2 + +for.body2: ; preds = %for.inc, %for.body + %lsr.iv = phi i32 [ %lsr.iv.next, %for.inc ], [ 1, %for.body ] + br label %for.inc + +for.inc: ; preds = %for.body2 + %idxprom = sext i32 %inc41 to i64 + %arrayidx = getelementptr inbounds [5 x i32], [5 x i32]* @b, i64 0, i64 %idxprom + %0 = load i32, i32* %arrayidx, align 4 + store i32 undef, i32* %arrayidx, align 4 + %cmp = icmp slt i32 %lsr.iv, 4 + %lsr.iv.next = add nuw nsw i32 %lsr.iv, 1 + br i1 %cmp, label %for.body2, label %for.cond1.for.end_crit_edge + +for.cond1.for.end_crit_edge: ; preds = %for.inc + br label %for.inc3 + +for.inc3: ; preds = %for.cond1.for.end_crit_edge + %inc4 = add nsw i32 %inc41, 1 + br i1 undef, label %for.body, label %for.cond.for.end5_crit_edge + +for.cond.for.end5_crit_edge: ; preds = %for.inc3 + ret void +}