Review notes (free text placed before the first diff header; git apply and
patch skip it):

NOTE(review): this copy of the patch is damaged and cannot be applied as-is.
Newlines have been collapsed, and text between angle brackets has been
dropped, e.g. "const SmallVectorImpl &InnerInductionVar", "isa(Val)" and
"dyn_cast(Val)" have lost their template arguments, and the removed-side
"InnerLoopPreheader &&" line appears to have lost its leading marker.
Regenerate it from the source revision before applying.

Purpose: lift the "only one inner induction variable" limitation. The
legality check now keeps every inner-loop induction PHI in
InnerLoopInductions (exposed via getInnerLoopInductions()), runs
isLoopStructureUnderstood over all of them, and collects each PHI's
increment instruction into InnerIndexVarIncs. LoopInterchangeTransform's
transform() then uses those PHIs instead of getInductionVariable() and
seeds the latch-split worklist with every increment.

Follow-ups to consider (not applied here):
 * "std::find(X.begin(), X.end(), V) != X.end()" (IsPathToIndVar and the
   MoveInstructions lambda) reads more directly as llvm::is_contained(X, V).
 * The latch scan matches with isIdenticalTo(), i.e. structural equality,
   as the pre-patch code did; InnerIndexVarIncs.count(&I) would test
   identity instead. Confirm which is intended.
 * transform() stores dyn_cast results in InnerIndexVarList and later
   cast()s them, so a null entry would trip the cast. currentLimitations()
   already bails out ("NoIncrementInInner") when an induction has no
   increment instruction, so casting at push time would state that
   invariant directly.
 * The patch drops the step that moved the induction PHI to the front of
   the inner header; confirm nothing later relies on that PHI order.
 * The new member comment says "Set of inner loop induction PHIs" but the
   member is a SmallVector.

diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp --- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -334,7 +334,8 @@ /// Check if the loop structure is understood. We do not handle triangular /// loops for now. - bool isLoopStructureUnderstood(PHINode *InnerInductionVar); + bool isLoopStructureUnderstood( + const SmallVectorImpl &InnerInductionVar); bool currentLimitations(); @@ -342,6 +343,10 @@ return OuterInnerReductions; } + const SmallVectorImpl &getInnerLoopInductions() const { + return InnerLoopInductions; + } + private: bool tightlyNested(Loop *Outer, Loop *Inner); bool containsUnsafeInstructions(BasicBlock *BB); @@ -365,6 +370,9 @@ /// Set of reduction PHIs taking part of a reduction across the inner and /// outer loop. SmallPtrSet OuterInnerReductions; + + /// Set of inner loop induction PHIs + SmallVector InnerLoopInductions; }; /// LoopInterchangeProfitability checks if it is profitable to interchange the @@ -636,24 +644,26 @@ } bool LoopInterchangeLegality::isLoopStructureUnderstood( - PHINode *InnerInduction) { - unsigned Num = InnerInduction->getNumOperands(); + const SmallVectorImpl &Inductions) { BasicBlock *InnerLoopPreheader = InnerLoop->getLoopPreheader(); - for (unsigned i = 0; i < Num; ++i) { - Value *Val = InnerInduction->getOperand(i); - if (isa(Val)) - continue; - Instruction *I = dyn_cast(Val); - if (!I) - return false; - // TODO: Handle triangular loops. - // e.g. for(int i=0;igetIncomingBlock(IncomBlockIndx) == InnerLoopPreheader && - !OuterLoop->isLoopInvariant(I)) { - return false; + for (PHINode *InnerInduction : Inductions) { + unsigned Num = InnerInduction->getNumOperands(); + for (unsigned i = 0; i < Num; ++i) { + Value *Val = InnerInduction->getOperand(i); + if (isa(Val)) + continue; + Instruction *I = dyn_cast(Val); + if (!I) + return false; + // TODO: Handle triangular loops. + // e.g.
for(int i=0;igetIncomingBlock(IncomBlockIndx) == + InnerLoopPreheader && + !OuterLoop->isLoopInvariant(I)) { + return false; + } } } @@ -683,8 +693,9 @@ // InnerInduction, or a binary operator that involves // InnerInduction and a constant. std::function IsPathToIndVar; - IsPathToIndVar = [&InnerInduction, &IsPathToIndVar](Value *V) -> bool { - if (V == InnerInduction) + IsPathToIndVar = [&Inductions, &IsPathToIndVar](Value *V) -> bool { + if (std::find(Inductions.begin(), Inductions.end(), V) != + Inductions.end()) return true; if (isa(V)) return true; @@ -699,6 +710,13 @@ return false; }; + // In case of multiple inner loop indvars, it is okay if LHS and RHS + // are both inner indvar related variables. + if (IsPathToIndVar(Op0) && IsPathToIndVar(Op1)) + return true; + + // Otherwise we check for the cmp instruction that compares an inner indvar + // related variable (Left) with a outer loop invariant (Right). if (IsPathToIndVar(Op0) && !isa(Op0)) { Left = Op0; Right = Op1; @@ -815,7 +833,6 @@ return true; } - PHINode *InnerInductionVar; SmallVector Inductions; if (!findInductionAndReductions(OuterLoop, Inductions, InnerLoop)) { LLVM_DEBUG( @@ -860,24 +877,8 @@ return true; } - // TODO: Currently we handle only loops with 1 induction variable. - if (Inductions.size() != 1) { - LLVM_DEBUG( - dbgs() << "We currently only support loops with 1 induction variable." - << "Failed to interchange due to current limitation\n"); - ORE->emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "MultiInductionInner", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "Only inner loops with 1 induction variable can be " - "interchanged currently."; - }); - return true; - } - InnerInductionVar = Inductions.pop_back_val(); - // TODO: Triangular loops are not handled for now.
- if (!isLoopStructureUnderstood(InnerInductionVar)) { + if (!isLoopStructureUnderstood(Inductions)) { LLVM_DEBUG(dbgs() << "Loop structure not understood by pass\n"); ORE->emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedStructureInner", @@ -888,35 +889,39 @@ return true; } - // TODO: Current limitation: Since we split the inner loop latch at the point - // were induction variable is incremented (induction.next); We cannot have - // more than 1 user of induction.next since it would result in broken code - // after split. - // e.g. - // for(i=0;igetIncomingBlock(0) == InnerLoopPreHeader) - InnerIndexVarInc = - dyn_cast(InnerInductionVar->getIncomingValue(1)); - else - InnerIndexVarInc = - dyn_cast(InnerInductionVar->getIncomingValue(0)); - - if (!InnerIndexVarInc) { - LLVM_DEBUG( - dbgs() << "Did not find an instruction to increment the induction " - << "variable.\n"); - ORE->emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "NoIncrementInInner", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "The inner loop does not increment the induction variable."; - }); - return true; + InnerLoopInductions = Inductions; + SmallPtrSet InnerIndexVarIncs; + for (PHINode *InnerInductionVar : InnerLoopInductions) { + // TODO: Current limitation: Since we split the inner loop latch at the + // point where induction variable is incremented (induction.next); We cannot + // have more than 1 user of induction.next since it would result in broken + // code after split. e.g.
for(i=0;igetIncomingBlock(0) == InnerLoopPreHeader) + InnerIndexVarInc = + dyn_cast(InnerInductionVar->getIncomingValue(1)); + else + InnerIndexVarInc = + dyn_cast(InnerInductionVar->getIncomingValue(0)); + + if (!InnerIndexVarInc) { + LLVM_DEBUG( + dbgs() << "Did not find an instruction to increment the induction " + << "variable: "; + InnerInductionVar->dump();); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NoIncrementInInner", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "The inner loop does not increment the induction variable."; + }); + return true; + } + InnerIndexVarIncs.insert(InnerIndexVarInc); } // Since we split the inner loop latch on this induction variable. Make sure @@ -932,7 +937,10 @@ // We found an instruction. If this is not induction variable then it is not // safe to split this loop latch. - if (!I.isIdenticalTo(InnerIndexVarInc)) { + if (std::none_of(InnerIndexVarIncs.begin(), InnerIndexVarIncs.end(), + [&I](const Instruction *InnerIndexVarInc) { + return InnerIndexVarInc->isIdenticalTo(&I); + })) { LLVM_DEBUG(dbgs() << "Found unsupported instructions between induction " << "variable increment and branch.\n"); ORE->emit([&]() { @@ -1347,25 +1355,25 @@ bool LoopInterchangeTransform::transform() { bool Transformed = false; - Instruction *InnerIndexVar; if (InnerLoop->getSubLoops().empty()) { BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); LLVM_DEBUG(dbgs() << "Splitting the inner loop latch\n"); - PHINode *InductionPHI = getInductionVariable(InnerLoop, SE); - if (!InductionPHI) { + auto &InductionPHIs = LIL.getInnerLoopInductions(); + if (InductionPHIs.empty()) { LLVM_DEBUG(dbgs() << "Failed to find the point to split loop latch \n"); return false; } - if (InductionPHI->getIncomingBlock(0) == InnerLoopPreHeader) - InnerIndexVar = dyn_cast(InductionPHI->getIncomingValue(1)); - else - InnerIndexVar = dyn_cast(InductionPHI->getIncomingValue(0)); - - // Ensure that InductionPHI is the first Phi
node. - if (&InductionPHI->getParent()->front() != InductionPHI) - InductionPHI->moveBefore(&InductionPHI->getParent()->front()); + SmallVector InnerIndexVarList; + for (PHINode *CurInductionPHI : InductionPHIs) { + if (CurInductionPHI->getIncomingBlock(0) == InnerLoopPreHeader) + InnerIndexVarList.push_back( + dyn_cast(CurInductionPHI->getIncomingValue(1))); + else + InnerIndexVarList.push_back( + dyn_cast(CurInductionPHI->getIncomingValue(0))); + } // Create a new latch block for the inner loop. We split at the // current latch's terminator and then move the condition and all @@ -1377,7 +1385,7 @@ SmallSetVector WorkList; unsigned i = 0; - auto MoveInstructions = [&i, &WorkList, this, InductionPHI, NewLatch]() { + auto MoveInstructions = [&i, &WorkList, this, &InductionPHIs, NewLatch]() { for (; i < WorkList.size(); i++) { // Duplicate instruction and move it the new latch. Update uses that // have been moved. @@ -1389,7 +1397,9 @@ for (Use &U : llvm::make_early_inc_range(WorkList[i]->uses())) { Instruction *UserI = cast(U.getUser()); if (!InnerLoop->contains(UserI->getParent()) || - UserI->getParent() == NewLatch || UserI == InductionPHI) + UserI->getParent() == NewLatch || + std::find(InductionPHIs.begin(), InductionPHIs.end(), UserI) != + InductionPHIs.end()) U.set(NewI); } // Add operands of moved instruction to the worklist, except if they are @@ -1398,7 +1408,8 @@ Instruction *OpI = dyn_cast(Op); if (!OpI || this->LI->getLoopFor(OpI->getParent()) != this->InnerLoop || - OpI == InductionPHI) + std::find(InductionPHIs.begin(), InductionPHIs.end(), OpI) != + InductionPHIs.end()) continue; WorkList.insert(OpI); } @@ -1412,7 +1423,8 @@ if (CondI) WorkList.insert(CondI); MoveInstructions(); - WorkList.insert(cast(InnerIndexVar)); + for (Instruction *InnerIndexVar : InnerIndexVarList) + WorkList.insert(cast(InnerIndexVar)); MoveInstructions(); // Splits the inner loops phi nodes out into a separate basic block.
@@ -1685,7 +1697,6 @@
   updateSuccessor(InnerLoopLatchPredecessorBI, InnerLoopLatch,
                   InnerLoopLatchSuccessor, DTUpdates);
 
-
   if (OuterLoopLatchBI->getSuccessor(0) == OuterLoopHeader)
     OuterLoopLatchSuccessor = OuterLoopLatchBI->getSuccessor(1);
   else
@@ -1710,9 +1721,14 @@
   auto &OuterInnerReductions = LIL.getOuterInnerReductions();
   // Now update the reduction PHIs in the inner and outer loop headers.
   SmallVector<PHINode *, 4> InnerLoopPHIs, OuterLoopPHIs;
-  for (PHINode &PHI : InnerLoopHeader->phis())
-    if (OuterInnerReductions.contains(&PHI))
-      InnerLoopPHIs.push_back(cast<PHINode>(&PHI));
+  auto &InnerLoopInductions = LIL.getInnerLoopInductions();
+  for (PHINode &PHI : InnerLoopHeader->phis()) {
+    // Induction PHIs stay in the inner loop header. Every other PHI here is
+    // expected to be a reduction PHI; the move loop below asserts that.
+    if (llvm::is_contained(InnerLoopInductions, &PHI))
+      continue;
+    InnerLoopPHIs.push_back(&PHI);
+  }
   for (PHINode &PHI : OuterLoopHeader->phis())
     if (OuterInnerReductions.contains(&PHI))
       OuterLoopPHIs.push_back(cast<PHINode>(&PHI));
@@ -1725,6 +1741,7 @@
     assert(OuterInnerReductions.count(PHI) && "Expected a reduction PHI node");
   }
   for (PHINode *PHI : InnerLoopPHIs) {
+    LLVM_DEBUG(dbgs() << "Inner loop reduction PHIs:\n"; PHI->dump(););
     PHI->moveBefore(OuterLoopHeader->getFirstNonPHI());
     assert(OuterInnerReductions.count(PHI) && "Expected a reduction PHI node");
   }
diff --git a/llvm/test/Transforms/LoopInterchange/interchangeable-innerloop-multiple-indvars.ll b/llvm/test/Transforms/LoopInterchange/interchangeable-innerloop-multiple-indvars.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LoopInterchange/interchangeable-innerloop-multiple-indvars.ll
@@ -0,0 +1,94 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -basic-aa -loop-interchange -verify-dom-info -verify-loop-info -verify-scev -verify-loop-lcssa -S | FileCheck %s
+
+@c = common dso_local local_unnamed_addr global i32 0, align 4
+@d = common dso_local local_unnamed_addr global i32 0, align 4
+@e = common dso_local local_unnamed_addr global i32 0, align 4
+@b = common dso_local local_unnamed_addr global [2 x [10 x i32]] zeroinitializer, align 4
+@a = common dso_local local_unnamed_addr global i32 0, align 4
+
+;; int a, c, d, e;
+;; int b[2][10];
+;; void fn1() {
+;;   for (; c; c++) {
+;;     d = 5;
+;;     e = 5;
+;;     for (; d, e; d--, e--)
+;;       a |= b[d][c + 9];
+;;   }
+;; }
+
+define void @test1() {
+; CHECK-LABEL: @test1(
+; CHECK:       for.body:
+; CHECK:         [[INDVARS_IV16:%.*]] = phi i64 [ [[TMP1:%.*]], [[FOR_BODY_LR_PH:%.*]] ], [ [[INDVARS_IV_NEXT17:%.*]], [[FOR_INC7:%.*]] ]
+; CHECK:         [[OR12:%.*]] = phi i32 [ [[OR:%.*]], [[FOR_INC7]] ], [ [[OR_LCSSA15:%.*]], [[FOR_BODY_LR_PH]] ]
+; CHECK:       for.body3:
+; CHECK:         [[INDVARS_IV:%.*]] = phi i64 [ [[TMP6:%.*]], [[FOR_BODY3_SPLIT:%.*]] ], [ 5, [[FOR_BODY3_PREHEADER:%.*]] ]
+; CHECK:         [[DEC11:%.*]] = phi i32 [ [[TMP7:%.*]], [[FOR_BODY3_SPLIT]] ], [ 5, [[FOR_BODY3_PREHEADER]] ]
+; CHECK:         [[OR_LCSSA15]] = phi i32 [ [[A_PROMOTED14:%.*]], [[FOR_BODY3_PREHEADER]] ], [ [[OR_LCSSA:%.*]], [[FOR_BODY3_SPLIT]] ]
+; CHECK:       for.body3.split:
+; CHECK:         [[OR_LCSSA]] = phi i32 [ [[OR]], [[FOR_INC7]] ]
+; CHECK:         [[TMP6]] = add nsw i64 [[INDVARS_IV]], -1
+; CHECK:         [[TMP7]] = add nsw i32 [[DEC11]], -1
+; CHECK:         [[TMP8:%.*]] = icmp eq i32 [[TMP7]], 0
+; CHECK:       for.inc7:
+; CHECK:         [[INDVARS_IV_NEXT17]] = add nsw i64 [[INDVARS_IV16]], 1
+; CHECK:       for.cond.for.end8_crit_edge:
+; CHECK:         [[OR_LCSSA_LCSSA:%.*]] = phi i32 [ [[OR_LCSSA]], [[FOR_BODY3_SPLIT]] ]
+; CHECK:         [[DEC6_LCSSA_LCSSA:%.*]] = phi i32 [ 0, [[FOR_BODY3_SPLIT]] ]
+; CHECK:       for.end8:
+; CHECK:         ret void
+;
+entry:
+  %.pr = load i32, i32* @c, align 4
+  %tobool10 = icmp eq i32 %.pr, 0
+  br i1 %tobool10, label %for.end8, label %for.body.lr.ph
+
+for.body.lr.ph:                                   ; preds = %entry
+  %a.promoted14 = load i32, i32* @a, align 4
+  %0 = sext i32 %.pr to i64
+  %1 = sub i32 -1, %.pr
+  %2 = zext i32 %1 to i64
+  %3 = add i64 %0, %2
+  br label %for.body
+
+for.body:                                         ; preds = %for.body.lr.ph, %for.inc7
+  %indvars.iv16 = phi i64 [ %0, %for.body.lr.ph ], [ %indvars.iv.next17, %for.inc7 ]
+  %or.lcssa15 = phi i32 [ %a.promoted14, %for.body.lr.ph ], [ %or.lcssa, %for.inc7 ]
+  %4 = add nsw i64 %indvars.iv16, 9
+  br label %for.body3
+
+for.body3:                                        ; preds = %for.body, %for.body3
+  %or12 = phi i32 [ %or.lcssa15, %for.body ], [ %or, %for.body3 ]
+  %indvars.iv = phi i64 [ 5, %for.body ], [ %indvars.iv.next, %for.body3 ]
+  %dec11 = phi i32 [ 5, %for.body ], [ %dec, %for.body3 ]
+  %arrayidx5 = getelementptr inbounds [2 x [10 x i32]], [2 x [10 x i32]]* @b, i64 0, i64 %indvars.iv, i64 %4
+  %5 = load i32, i32* %arrayidx5, align 4
+  %or = or i32 %or12, %5
+  %indvars.iv.next = add nsw i64 %indvars.iv, -1
+  %dec = add nsw i32 %dec11, -1
+  %tobool2 = icmp eq i32 %dec, 0
+  br i1 %tobool2, label %for.inc7, label %for.body3
+
+for.inc7:                                         ; preds = %for.body3
+  %or.lcssa = phi i32 [ %or, %for.body3 ]
+  %indvars.iv.next17 = add nsw i64 %indvars.iv16, 1
+  %6 = trunc i64 %indvars.iv.next17 to i32
+  %tobool = icmp eq i32 %6, 0
+  br i1 %tobool, label %for.cond.for.end8_crit_edge, label %for.body
+
+for.cond.for.end8_crit_edge:                      ; preds = %for.inc7
+  %or.lcssa.lcssa = phi i32 [ %or.lcssa, %for.inc7 ]
+  %dec6.lcssa.lcssa = phi i32 [ 0, %for.inc7 ]
+  %7 = add i64 %3, 1
+  %8 = trunc i64 %7 to i32
+  store i32 0, i32* @d, align 4
+  store i32 %dec6.lcssa.lcssa, i32* @e, align 4
+  store i32 %or.lcssa.lcssa, i32* @a, align 4
+  store i32 %8, i32* @c, align 4
+  br label %for.end8
+
+for.end8:                                         ; preds = %for.cond.for.end8_crit_edge, %entry
+  ret void
+}