[LoopInterchange] Reuse LoopNest::checkLoopsStructure for the tightly-nested check

* Make checkLoopsStructure a static member of LoopNest so that it can be
  called from outside LoopNestAnalysis.cpp.
* Add LoopNest::isTotallyNested() and use it in LoopInterchange's
  run(LoopNest &) in place of the hand-written parent-chain loop.
* Replace the outer-header successor check in
  LoopInterchangeLegality::tightlyNested with a call to
  LoopNest::checkLoopsStructure.
* Add regression test pr48113.ll. The test greps the -debug-only output, which
  is only available in assertion-enabled builds, hence "REQUIRES: asserts".

diff --git a/llvm/include/llvm/Analysis/LoopNestAnalysis.h b/llvm/include/llvm/Analysis/LoopNestAnalysis.h
--- a/llvm/include/llvm/Analysis/LoopNestAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopNestAnalysis.h
@@ -66,6 +66,18 @@
   static const BasicBlock &skipEmptyBlockUntil(const BasicBlock *From,
                                                const BasicBlock *End);
 
+  /// Determine whether the loop structure satisfies the basic requirements for
+  /// perfect nesting:
+  ///  - the inner loop should be the outer loop's only child
+  ///  - the outer loop header should 'flow' into the inner loop preheader
+  ///    or jump around the inner loop to the outer loop latch
+  ///  - if the inner loop latch exits the inner loop, it should 'flow' into
+  ///    the outer loop latch.
+  /// Returns true if the loop structure satisfies the basic requirements and
+  /// false otherwise.
+  static bool checkLoopsStructure(const Loop &OuterLoop, const Loop &InnerLoop,
+                                  ScalarEvolution &SE);
+
   /// Return the outermost loop in the loop nest.
   Loop &getOutermostLoop() const { return *Loops.front(); }
 
@@ -144,6 +156,13 @@
     return Loops.front()->getHeader()->getParent();
   }
 
+  /// Return whether all loops in the loopnest (except for the innermost one)
+  /// have exactly one subloop and thus the loopnest is "totally-nested".
+  bool isTotallyNested() const {
+    return all_of(make_range(Loops.begin(), std::prev(Loops.end())),
+                  [](const Loop *L) { return L->getSubLoops().size() == 1; });
+  }
+
   StringRef getName() const { return Loops.front()->getName(); }
 
 protected:
diff --git a/llvm/lib/Analysis/LoopNestAnalysis.cpp b/llvm/lib/Analysis/LoopNestAnalysis.cpp
--- a/llvm/lib/Analysis/LoopNestAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopNestAnalysis.cpp
@@ -24,18 +24,6 @@
 static const char *VerboseDebug = DEBUG_TYPE "-verbose";
 #endif
 
-/// Determine whether the loops structure violates basic requirements for
-/// perfect nesting:
-///  - the inner loop should be the outer loop's only child
-///  - the outer loop header should 'flow' into the inner loop preheader
-///    or jump around the inner loop to the outer loop latch
-///  - if the inner loop latch exits the inner loop, it should 'flow' into
-///    the outer loop latch.
-/// Returns true if the loop structure satisfies the basic requirements and
-/// false otherwise.
-static bool checkLoopsStructure(const Loop &OuterLoop, const Loop &InnerLoop,
-                                ScalarEvolution &SE);
-
 //===----------------------------------------------------------------------===//
 // LoopNest implementation
 //
@@ -230,8 +218,8 @@
   return (BB == End) ? *End : *PredBB;
 }
 
-static bool checkLoopsStructure(const Loop &OuterLoop, const Loop &InnerLoop,
-                                ScalarEvolution &SE) {
+bool LoopNest::checkLoopsStructure(const Loop &OuterLoop, const Loop &InnerLoop,
+                                   ScalarEvolution &SE) {
   // The inner loop must be the only outer loop's child.
   if ((OuterLoop.getSubLoops().size() != 1) ||
       (InnerLoop.getParentLoop() != &OuterLoop))
diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
--- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
@@ -450,11 +450,9 @@
   }
 
   bool run(LoopNest &LN) {
-    const auto &LoopList = LN.getLoops();
-    for (unsigned I = 1; I < LoopList.size(); ++I)
-      if (LoopList[I]->getParentLoop() != LoopList[I - 1])
-        return false;
-    return processLoopList(LoopList);
+    if (!LN.isTotallyNested())
+      return false;
+    return processLoopList(LN.getLoops());
   }
 
   bool isComputableLoopNest(ArrayRef<Loop *> LoopList) {
@@ -588,26 +586,14 @@
 }
 
 bool LoopInterchangeLegality::tightlyNested(Loop *OuterLoop, Loop *InnerLoop) {
-  BasicBlock *OuterLoopHeader = OuterLoop->getHeader();
-  BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader();
-  BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch();
-
   LLVM_DEBUG(dbgs() << "Checking if loops are tightly nested\n");
 
-  // A perfectly nested loop will not have any branch in between the outer and
-  // inner block i.e. outer header will branch to either inner preheader and
-  // outerloop latch.
-  BranchInst *OuterLoopHeaderBI =
-      dyn_cast<BranchInst>(OuterLoopHeader->getTerminator());
-  if (!OuterLoopHeaderBI)
+  if (!LoopNest::checkLoopsStructure(*OuterLoop, *InnerLoop, *SE))
     return false;
+  BasicBlock *OuterLoopHeader = OuterLoop->getHeader();
+  BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader();
+  BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch();
 
-  for (BasicBlock *Succ : successors(OuterLoopHeaderBI))
-    if (Succ != InnerLoopPreHeader && Succ != InnerLoop->getHeader() &&
-        Succ != OuterLoopLatch)
-      return false;
-
-  LLVM_DEBUG(dbgs() << "Checking instructions in Loop header and Loop latch\n");
   // We do not have any basic block in between now make sure the outer header
   // and outer loop latch doesn't contain any unsafe instructions.
   if (containsUnsafeInstructions(OuterLoopHeader) ||
diff --git a/llvm/test/Transforms/LoopInterchange/pr48113.ll b/llvm/test/Transforms/LoopInterchange/pr48113.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LoopInterchange/pr48113.ll
@@ -0,0 +1,95 @@
+; REQUIRES: asserts
+; RUN: opt -S -passes='loop-interchange' -debug-only=loop-interchange < %s 2>&1 > /dev/null | FileCheck %s
+@.str = private unnamed_addr constant [5 x i8] c"%lX\0A\00", align 1
+@d = dso_local local_unnamed_addr global i8 0, align 1
+@a = dso_local local_unnamed_addr global i8 0, align 1
+@b = dso_local local_unnamed_addr global [1 x [2 x i32]] zeroinitializer, align 4
+@c = dso_local local_unnamed_addr global [1 x [9 x i8]] zeroinitializer, align 1
+@e = dso_local local_unnamed_addr global i8* null, align 8
+@g = dso_local local_unnamed_addr global i16 0, align 2
+@f = internal unnamed_addr global i32 0, align 4
+
+; CHECK:       Processing InnerLoopId = 1 and OuterLoopId = 0
+; CHECK-NEXT:  Checking if loops are tightly nested
+; CHECK-NEXT:  Loops not tightly nested
+; CHECK-NEXT:  Not interchanging loops. Cannot prove legality
+
+; Function Attrs: nofree nounwind optsize uwtable
+define dso_local i32 @main() local_unnamed_addr #0 {
+entry:
+  %.pr.i = load i32, i32* @f, align 4, !tbaa !2
+  %cmp3.i = icmp ult i32 %.pr.i, 3
+  br i1 %cmp3.i, label %for.cond1.preheader.lr.ph.i, label %h.exit
+
+for.cond1.preheader.lr.ph.i:                      ; preds = %entry
+  %0 = load i8, i8* @a, align 1, !tbaa !6
+  %tobool.not.i = icmp eq i8 %0, 0
+  %1 = load i8*, i8** @e, align 8
+  %2 = zext i32 %.pr.i to i64
+  br label %for.cond1.preheader.i
+
+for.cond1.preheader.i:                            ; preds = %if.end.i, %for.cond1.preheader.lr.ph.i
+  %indvars.iv4.i = phi i64 [ %2, %for.cond1.preheader.lr.ph.i ], [ %indvars.iv.next5.i, %if.end.i ]
+  br label %for.body3.i
+
+for.body3.i:                                      ; preds = %for.body3.i, %for.cond1.preheader.i
+  %indvars.iv.i = phi i64 [ 0, %for.cond1.preheader.i ], [ %indvars.iv.next.i, %for.body3.i ]
+  %arrayidx5.i = getelementptr inbounds [1 x [9 x i8]], [1 x [9 x i8]]* @c, i64 0, i64 %indvars.iv.i, i64 %indvars.iv4.i
+  %3 = load i8, i8* %arrayidx5.i, align 1, !tbaa !6
+  %conv.i = sext i8 %3 to i32
+  %arrayidx8.i = getelementptr inbounds [1 x [2 x i32]], [1 x [2 x i32]]* @b, i64 0, i64 %indvars.iv.i, i64 1
+  store i32 %conv.i, i32* %arrayidx8.i, align 4, !tbaa !2
+  %indvars.iv.next.i = add nuw nsw i64 %indvars.iv.i, 1
+  %exitcond.not.i = icmp eq i64 %indvars.iv.next.i, 3
+  br i1 %exitcond.not.i, label %for.end.i, label %for.body3.i, !llvm.loop !7
+
+for.end.i:                                        ; preds = %for.body3.i
+  %4 = load i8, i8* @d, align 1, !tbaa !6
+  %inc9.i = add i8 %4, 1
+  store i8 %inc9.i, i8* @d, align 1, !tbaa !6
+  br i1 %tobool.not.i, label %if.end.i, label %if.then.i
+
+if.then.i:                                        ; preds = %for.end.i
+  %5 = load i8, i8* %1, align 1, !tbaa !6
+  %conv10.i = sext i8 %5 to i16
+  store i16 %conv10.i, i16* @g, align 2, !tbaa !9
+  br label %if.end.i
+
+if.end.i:                                         ; preds = %if.then.i, %for.end.i
+  %indvars.iv.next5.i = add nuw nsw i64 %indvars.iv4.i, 1
+  %exitcond6.not.i = icmp eq i64 %indvars.iv.next5.i, 3
+  br i1 %exitcond6.not.i, label %for.cond.for.end13_crit_edge.i, label %for.cond1.preheader.i, !llvm.loop !11
+
+for.cond.for.end13_crit_edge.i:                   ; preds = %if.end.i
+  store i32 3, i32* @f, align 4, !tbaa !2
+  br label %h.exit
+
+h.exit:                                           ; preds = %entry, %for.cond.for.end13_crit_edge.i
+  %6 = load i8, i8* @d, align 1, !tbaa !6
+  %conv = sext i8 %6 to i32
+  %call1 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([5 x i8], [5 x i8]* @.str, i64 0, i64 0), i32 %conv) #2
+  ret i32 0
+}
+
+; Function Attrs: nofree nounwind optsize
+declare dso_local noundef i32 @printf(i8* nocapture noundef readonly, ...) local_unnamed_addr #1
+
+attributes #0 = { nofree nounwind optsize uwtable "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+attributes #1 = { nofree nounwind optsize "frame-pointer"="none" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+attributes #2 = { optsize }
+
+!llvm.module.flags = !{!0}
+!llvm.ident = !{!1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{!"clang version 13.0.0 (git@github.com:llvm/llvm-project.git ffba9e596d09a1d41f83102756e145b59d3f8495)"}
+!2 = !{!3, !3, i64 0}
+!3 = !{!"int", !4, i64 0}
+!4 = !{!"omnipotent char", !5, i64 0}
+!5 = !{!"Simple C/C++ TBAA"}
+!6 = !{!4, !4, i64 0}
+!7 = distinct !{!7, !8}
+!8 = !{!"llvm.loop.mustprogress"}
+!9 = !{!10, !10, i64 0}
+!10 = !{!"short", !4, i64 0}
+!11 = distinct !{!11, !8}