diff --git a/llvm/include/llvm/Analysis/LoopNestAnalysis.h b/llvm/include/llvm/Analysis/LoopNestAnalysis.h
--- a/llvm/include/llvm/Analysis/LoopNestAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopNestAnalysis.h
@@ -139,6 +139,11 @@
     return all_of(Loops, [](const Loop *L) { return L->isRotatedForm(); });
   }
 
+  /// Return the function to which the loop-nest belongs.
+  Function *getParent() const {
+    return Loops.front()->getHeader()->getParent();
+  }
+
   StringRef getName() const { return Loops.front()->getName(); }
 
 protected:
diff --git a/llvm/include/llvm/Transforms/Scalar/LoopInterchange.h b/llvm/include/llvm/Transforms/Scalar/LoopInterchange.h
--- a/llvm/include/llvm/Transforms/Scalar/LoopInterchange.h
+++ b/llvm/include/llvm/Transforms/Scalar/LoopInterchange.h
@@ -15,7 +15,7 @@
 namespace llvm {
 
 struct LoopInterchangePass : public PassInfoMixin<LoopInterchangePass> {
-  PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM,
+  PreservedAnalyses run(LoopNest &LN, LoopAnalysisManager &AM,
                         LoopStandardAnalysisResults &AR, LPMUpdater &U);
 };
 
diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
--- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
@@ -449,7 +449,18 @@
     return processLoopList(populateWorklist(*L));
   }
 
-  bool isComputableLoopNest(LoopVector LoopList) {
+  /// Run interchange on the loop nest \p LN. It is only attempted when each
+  /// loop in LN.getLoops() is the direct parent of the loop listed after it,
+  /// i.e. the loops form a single chain; otherwise return false.
+  bool run(LoopNest &LN) {
+    ArrayRef<Loop *> LoopList = LN.getLoops();
+    for (unsigned I = 1; I < LoopList.size(); ++I)
+      if (LoopList[I]->getParentLoop() != LoopList[I - 1])
+        return false;
+    return processLoopList(LoopList);
+  }
+
+  bool isComputableLoopNest(ArrayRef<Loop *> LoopList) {
     for (Loop *L : LoopList) {
       const SCEV *ExitCountOuter = SE->getBackedgeTakenCount(L);
       if (isa<SCEVCouldNotCompute>(ExitCountOuter)) {
@@ -468,13 +479,13 @@
     return true;
   }
 
-  unsigned selectLoopForInterchange(const LoopVector &LoopList) {
+  unsigned selectLoopForInterchange(ArrayRef<Loop *> LoopList) {
     // TODO: Add a better heuristic to select the loop to be interchanged based
     // on the dependence matrix. Currently we select the innermost loop.
     return LoopList.size() - 1;
   }
 
-  bool processLoopList(LoopVector LoopList) {
+  bool processLoopList(ArrayRef<Loop *> LoopList) {
     bool Changed = false;
     unsigned LoopNestDepth = LoopList.size();
     if (LoopNestDepth < 2) {
@@ -515,14 +526,14 @@
 
     unsigned SelecLoopId = selectLoopForInterchange(LoopList);
     // Move the selected loop outwards to the best possible position.
+    // The selected loop moves one level outwards after every successful
+    // interchange, so keep a handle to it rather than reading LoopList[i].
+    Loop *LoopToBeInterchanged = LoopList[SelecLoopId];
     for (unsigned i = SelecLoopId; i > 0; i--) {
-      bool Interchanged = processLoop(LoopList[i], LoopList[i - 1], i, i - 1,
-                                      LoopNestExit, DependencyMatrix);
+      bool Interchanged = processLoop(LoopToBeInterchanged, LoopList[i - 1], i,
+                                      i - 1, LoopNestExit, DependencyMatrix);
       if (!Interchanged)
         return Changed;
-      // Loops interchanged reflect the same in LoopList
-      std::swap(LoopList[i - 1], LoopList[i]);
-
       // Update the DependencyMatrix
       interChangeDependencies(DependencyMatrix, i, i - 1);
 #ifdef DUMP_DEP_MATRICIES
@@ -539,7 +550,6 @@
                    std::vector<std::vector<char>> &DependencyMatrix) {
     LLVM_DEBUG(dbgs() << "Processing InnerLoopId = " << InnerLoopId
                       << " and OuterLoopId = " << OuterLoopId << "\n");
-
     LoopInterchangeLegality LIL(OuterLoop, InnerLoop, SE, ORE);
     if (!LIL.canInterchangeLoops(InnerLoopId, OuterLoopId, DependencyMatrix)) {
       LLVM_DEBUG(dbgs() << "Not interchanging loops. Cannot prove legality.\n");
@@ -1680,14 +1690,15 @@
   return new LoopInterchangeLegacyPass();
 }
 
-PreservedAnalyses LoopInterchangePass::run(Loop &L, LoopAnalysisManager &AM,
+PreservedAnalyses LoopInterchangePass::run(LoopNest &LN,
+                                           LoopAnalysisManager &AM,
                                            LoopStandardAnalysisResults &AR,
                                            LPMUpdater &U) {
-  Function &F = *L.getHeader()->getParent();
+  Function &F = *LN.getParent();
 
   DependenceInfo DI(&F, &AR.AA, &AR.SE, &AR.LI);
   OptimizationRemarkEmitter ORE(&F);
-  if (!LoopInterchange(&AR.SE, &AR.LI, &DI, &AR.DT, &ORE).run(&L))
+  if (!LoopInterchange(&AR.SE, &AR.LI, &DI, &AR.DT, &ORE).run(LN))
     return PreservedAnalyses::all();
   return getLoopPassPreservedAnalyses();
 }
diff --git a/llvm/test/Transforms/LoopInterchange/interchanged-loop-nest-3.ll b/llvm/test/Transforms/LoopInterchange/interchanged-loop-nest-3.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LoopInterchange/interchanged-loop-nest-3.ll
@@ -0,0 +1,56 @@
+; REQUIRES: asserts
+; RUN: opt < %s -basic-aa -loop-interchange -verify-dom-info -verify-loop-info \
+; RUN:     -S -debug 2>&1 | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@D = common global [100 x [100 x [100 x i32]]] zeroinitializer
+
+;; Test for interchange in a loop nest of depth greater than 2.
+;; for(int i=0;i<100;i++)
+;;   for(int j=0;j<100;j++)
+;;     for(int k=0;k<100;k++)
+;;       D[k][j][i] = D[k][j][i]+t;
+
+; CHECK: Processing InnerLoopId = 2 and OuterLoopId = 1
+; CHECK: Loops interchanged.
+
+; CHECK: Processing InnerLoopId = 1 and OuterLoopId = 0
+; CHECK: Loops interchanged.
+
+define void @interchange_08(i32 %t){
+entry:
+  br label %for.cond1.preheader
+
+for.cond1.preheader:                              ; preds = %for.inc15, %entry
+  %i.028 = phi i32 [ 0, %entry ], [ %inc16, %for.inc15 ]                ; i: IV of the outermost loop
+  br label %for.cond4.preheader
+
+for.cond4.preheader:                              ; preds = %for.inc12, %for.cond1.preheader
+  %j.027 = phi i32 [ 0, %for.cond1.preheader ], [ %inc13, %for.inc12 ]  ; j: IV of the middle loop
+  br label %for.body6
+
+for.body6:                                        ; preds = %for.body6, %for.cond4.preheader
+  %k.026 = phi i32 [ 0, %for.cond4.preheader ], [ %inc, %for.body6 ]    ; k: IV of the innermost loop
+  %arrayidx8 = getelementptr inbounds [100 x [100 x [100 x i32]]], [100 x [100 x [100 x i32]]]* @D, i32 0, i32 %k.026, i32 %j.027, i32 %i.028 ; &D[k][j][i]: innermost IV indexes the outermost dimension
+  %0 = load i32, i32* %arrayidx8
+  %add = add nsw i32 %0, %t
+  store i32 %add, i32* %arrayidx8
+  %inc = add nuw nsw i32 %k.026, 1
+  %exitcond = icmp eq i32 %inc, 100                                     ; k runs 0..99
+  br i1 %exitcond, label %for.inc12, label %for.body6
+
+for.inc12:                                        ; preds = %for.body6
+  %inc13 = add nuw nsw i32 %j.027, 1
+  %exitcond29 = icmp eq i32 %inc13, 100                                 ; j runs 0..99
+  br i1 %exitcond29, label %for.inc15, label %for.cond4.preheader
+
+for.inc15:                                        ; preds = %for.inc12
+  %inc16 = add nuw nsw i32 %i.028, 1
+  %exitcond30 = icmp eq i32 %inc16, 100                                 ; i runs 0..99
+  br i1 %exitcond30, label %for.end17, label %for.cond1.preheader
+
+for.end17:                                        ; preds = %for.inc15
+  ret void
+}