Index: llvm/lib/Transforms/Scalar/LoopInterchange.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LoopInterchange.cpp
+++ llvm/lib/Transforms/Scalar/LoopInterchange.cpp
@@ -584,38 +584,67 @@
   });
 }
 
-bool LoopInterchangeLegality::tightlyNested(Loop *OuterLoop, Loop *InnerLoop) {
-  BasicBlock *OuterLoopHeader = OuterLoop->getHeader();
-  BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader();
-  BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch();
+/// Returns true if \p LoopInner is tightly nested in \p LoopOuter: the outer
+/// header only branches to the inner preheader/header or the outer latch, the
+/// inner loop's exit block is or leads directly to the outer latch, and the
+/// outer header/latch and inner preheader/exit hold no unsafe instructions.
+bool LoopInterchangeLegality::tightlyNested(Loop *LoopOuter, Loop *LoopInner) {
+  BasicBlock *LoopOuterHeader = LoopOuter->getHeader();
+  BasicBlock *LoopInnerPreHeader = LoopInner->getLoopPreheader();
+  BasicBlock *LoopOuterLatch = LoopOuter->getLoopLatch();
+  BasicBlock *LoopInnerExit = LoopInner->getExitBlock();
 
   LLVM_DEBUG(dbgs() << "Checking if loops are tightly nested\n");
 
   // A perfectly nested loop will not have any branch in between the outer and
   // inner block i.e. outer header will branch to either inner preheader and
-  // outerloop latch.
-  BranchInst *OuterLoopHeaderBI =
-      dyn_cast<BranchInst>(OuterLoopHeader->getTerminator());
-  if (!OuterLoopHeaderBI)
+  // outer latch.
+  BranchInst *LoopOuterHeaderBI =
+      dyn_cast<BranchInst>(LoopOuterHeader->getTerminator());
+  if (!LoopOuterHeaderBI)
     return false;
 
-  for (BasicBlock *Succ : successors(OuterLoopHeaderBI))
-    if (Succ != InnerLoopPreHeader && Succ != InnerLoop->getHeader() &&
-        Succ != OuterLoopLatch)
+  for (BasicBlock *Succ : successors(LoopOuterHeaderBI))
+    if (Succ != LoopInnerPreHeader && Succ != LoopInner->getHeader() &&
+        Succ != LoopOuterLatch)
       return false;
 
+  // getExitBlock() returns null when the inner loop does not have exactly one
+  // exit block; bail out instead of dereferencing a null pointer below.
+  if (!LoopInnerExit) {
+    LLVM_DEBUG(dbgs() << "Inner loop does not have a unique exit block.\n");
+    return false;
+  }
+
+  // Ensure the inner loop exit block either is the outer loop latch or has the
+  // outer loop latch as its single successor.
+  const BasicBlock *SuccInner = LoopInnerExit->getSingleSuccessor();
+  if ((!SuccInner || (SuccInner != LoopOuterLatch)) &&
+      LoopInnerExit != LoopOuterLatch) {
+    LLVM_DEBUG(
+        dbgs() << "Inner loop exit block " << *LoopInnerExit
+               << " does not directly lead/equal to the outer loop latch.\n";);
+    return false;
+  }
+
   LLVM_DEBUG(dbgs() << "Checking instructions in Loop header and Loop latch\n");
   // We do not have any basic block in between now make sure the outer header
   // and outer loop latch doesn't contain any unsafe instructions.
-  if (containsUnsafeInstructions(OuterLoopHeader) ||
-      containsUnsafeInstructions(OuterLoopLatch))
+  if (containsUnsafeInstructions(LoopOuterHeader) ||
+      containsUnsafeInstructions(LoopOuterLatch))
     return false;
 
   // Also make sure the inner loop preheader does not contain any unsafe
   // instructions. Note that all instructions in the preheader will be moved to
   // the outer loop header when interchanging.
-  if (InnerLoopPreHeader != OuterLoopHeader &&
-      containsUnsafeInstructions(InnerLoopPreHeader))
+  if (LoopInnerPreHeader != LoopOuterHeader &&
+      containsUnsafeInstructions(LoopInnerPreHeader))
+    return false;
+
+  // The inner loop exit block will be moved into the new inner loop after
+  // interchanging, so it must not contain unsafe instructions either. This is
+  // checked even when the inner preheader is the outer header.
+  if (containsUnsafeInstructions(LoopInnerExit))
     return false;
 
   LLVM_DEBUG(dbgs() << "Loops are perfectly nested\n");
@@ -990,17 +1019,23 @@
   }
 
   // Check if the loops are tightly nested.
-  if (!tightlyNested(OuterLoop, InnerLoop)) {
-    LLVM_DEBUG(dbgs() << "Loops not tightly nested\n");
-    ORE->emit([&]() {
-      return OptimizationRemarkMissed(DEBUG_TYPE, "NotTightlyNested",
-                                      InnerLoop->getStartLoc(),
-                                      InnerLoop->getHeader())
-             << "Cannot interchange loops because they are not tightly "
-                "nested.";
-    });
-    return false;
-  }
+  Loop *LoopInner = InnerLoop;
+  Loop *LoopOuter = nullptr;
+  do {
+    LoopOuter = LoopInner->getParentLoop();
+    if (!tightlyNested(LoopOuter, LoopInner)) {
+      LLVM_DEBUG(dbgs() << "Loops not tightly nested\n");
+      ORE->emit([&]() {
+        return OptimizationRemarkMissed(DEBUG_TYPE, "NotTightlyNested",
+                                        LoopInner->getStartLoc(),
+                                        LoopInner->getHeader())
+               << "Cannot interchange loops because they are not tightly "
+                  "nested.";
+      });
+      return false;
+    }
+    LoopInner = LoopOuter;
+  } while (LoopOuter != OuterLoop);
 
   if (!areInnerLoopExitPHIsSupported(OuterLoop, InnerLoop,
                                      OuterInnerReductions)) {
Index: llvm/test/Transforms/LoopInterchange/imperfectly-nested.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/LoopInterchange/imperfectly-nested.ll
@@ -0,0 +1,67 @@
+; REQUIRES: asserts
+; RUN: opt < %s -basicaa -loop-interchange -S -debug 2>&1 | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-unknown-linux-gnu"
+
+@a = common dso_local local_unnamed_addr global i8 0, align 1
+@b = common dso_local local_unnamed_addr global [1 x [2 x i32]] zeroinitializer, align 4
+@c = common dso_local local_unnamed_addr global [1 x [9 x i8]] zeroinitializer, align 1
+
+;; Not tightly nested: the inner loop exit block (for.cond.cleanup.i) branches
+;; conditionally, not straight to the outer latch. It must not be interchanged.
+;;
+;; char a;
+;; int b[][2];
+;; char c[][9];
+;; int h() {
+;;   for (; f <= 2; f++) {
+;;     for (int j = 0; j <= 2; j++) {
+;;       b[j][1] = c[j][f];
+;;     }
+;;     if (a)
+;;   }
+;;   return 0;
+;; }
+
+; CHECK: Loops not tightly nested
+; CHECK: Not interchanging loops. Cannot prove legality.
+
+define i32 @h() {
+outer.preheader:
+  %0 = load i8, i8* @a, align 1
+  %tobool.i = icmp eq i8 %0, 0
+  br label %outer.header
+
+outer.header: ; preds = %if.end.i, %outer.preheader
+  %indvars.iv4.i = phi i64 [ 0, %outer.preheader ], [ %indvars.iv.next5.i, %if.end.i ]
+  br label %inner.header
+
+inner.header: ; preds = %inner.header, %outer.header
+  %indvars.iv.i = phi i64 [ 0, %outer.header ], [ %indvars.iv.next.i, %inner.header ]
+  %arrayidx5.i = getelementptr inbounds [1 x [9 x i8]], [1 x [9 x i8]]* @c, i64 0, i64 %indvars.iv.i, i64 %indvars.iv4.i
+  %1 = load i8, i8* %arrayidx5.i, align 1
+  %conv.i = zext i8 %1 to i32
+  %arrayidx8.i = getelementptr inbounds [1 x [2 x i32]], [1 x [2 x i32]]* @b, i64 0, i64 %indvars.iv.i, i64 1
+  store i32 %conv.i, i32* %arrayidx8.i, align 4
+  %indvars.iv.next.i = add nuw nsw i64 %indvars.iv.i, 1
+  %exitcond.i = icmp eq i64 %indvars.iv.next.i, 3
+  br i1 %exitcond.i, label %for.cond.cleanup.i, label %inner.header
+
+for.cond.cleanup.i: ; preds = %inner.header
+  br i1 %tobool.i, label %if.end.i, label %if.then.i
+
+if.then.i: ; preds = %for.cond.cleanup.i
+  br label %if.end.i
+
+if.end.i: ; preds = %if.then.i, %for.cond.cleanup.i
+  %indvars.iv.next5.i = add nuw nsw i64 %indvars.iv4.i, 1
+  %exitcond6.i = icmp eq i64 %indvars.iv.next5.i, 3
+  br i1 %exitcond6.i, label %outer.exit, label %outer.header
+
+outer.exit: ; preds = %if.end.i
+  br label %h.exit
+
+h.exit: ; preds = %outer.exit
+  ret i32 0
+}