Index: llvm/lib/Transforms/Scalar/LoopInterchange.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LoopInterchange.cpp
+++ llvm/lib/Transforms/Scalar/LoopInterchange.cpp
@@ -780,8 +780,31 @@
                "Phis in loop header should have exactly 2 incoming values");
         // Check if we have a PHI node in the outer loop that has a reduction
         // result from the inner loop as an incoming value.
+        //
+        // For multi-level nested loops InnerLoop might not be the innermost
+        // loop, and the reduction phi has to be located in the innermost loop
+        // first. From there, follow the chain of header phis outwards (taking
+        // each phi's value incoming from its loop's preheader) until reaching
+        // InnerLoop's header. If the chain is broken, InnerRedPhi is reset to
+        // null so the check below fails.
         Value *V = followLCSSA(PHI.getIncomingValueForBlock(L->getLoopLatch()));
-        PHINode *InnerRedPhi = findInnerReductionPhi(InnerLoop, V);
+        Loop *InnerMostLoop = InnerLoop;
+        while (!InnerMostLoop->getSubLoops().empty())
+          InnerMostLoop = InnerMostLoop->getSubLoops().front();
+        PHINode *InnerRedPhi = findInnerReductionPhi(InnerMostLoop, V);
+        for (Loop *CurLoop = InnerMostLoop;
+             InnerRedPhi && CurLoop != InnerLoop;
+             CurLoop = CurLoop->getParentLoop()) {
+          BasicBlock *Preheader = CurLoop->getLoopPreheader();
+          if (!Preheader || InnerRedPhi->getParent() != CurLoop->getHeader()) {
+            InnerRedPhi = nullptr;
+            break;
+          }
+          InnerRedPhi = dyn_cast<PHINode>(
+              InnerRedPhi->getIncomingValueForBlock(Preheader));
+        }
+        if (InnerRedPhi && InnerRedPhi->getParent() != InnerLoop->getHeader())
+          InnerRedPhi = nullptr;
         if (!InnerRedPhi ||
             !llvm::is_contained(InnerRedPhi->incoming_values(), &PHI)) {
           LLVM_DEBUG(
Index: llvm/test/Transforms/LoopInterchange/pr43326-ideal-access-pattern.ll
===================================================================
--- llvm/test/Transforms/LoopInterchange/pr43326-ideal-access-pattern.ll
+++ llvm/test/Transforms/LoopInterchange/pr43326-ideal-access-pattern.ll
@@ -61,4 +61,73 @@
   %indvars.innermost.next = add nuw nsw i64 %indvars.innermost, 1
   %exitcond.innermost = icmp ne i64 %indvars.innermost.next, 10
   br i1 %exitcond.innermost, label %for.innermost, label %for.middle.latch
+}
+
+
+; Triply nested loop with a sum reduction carried through all three loops;
+; it should be interchanged three times to reach the ideal access pattern.
+;
+; unsigned f(int e[10][10][10], int f[10][10][10]) {
+;   unsigned x = 0;
+;   for (int a = 0; a < 10; a++) {
+;     for (int b = 0; b < 10; b++) {
+;       for (int c = 0; c < 10; c++) {
+;         x += e[c][b][a];
+;       }
+;     }
+;   }
+;   return x;
+; }
+
+; REMARKS: --- !Passed
+; REMARKS-NEXT: Pass: loop-interchange
+; REMARKS-NEXT: Name: Interchanged
+; REMARKS-NEXT: Function: pr43326-triply-nested-reduction
+; REMARKS: --- !Passed
+; REMARKS-NEXT: Pass: loop-interchange
+; REMARKS-NEXT: Name: Interchanged
+; REMARKS-NEXT: Function: pr43326-triply-nested-reduction
+; REMARKS: --- !Passed
+; REMARKS-NEXT: Pass: loop-interchange
+; REMARKS-NEXT: Name: Interchanged
+; REMARKS-NEXT: Function: pr43326-triply-nested-reduction
+define i32 @pr43326-triply-nested-reduction([10 x [10 x i32]]* %e, [10 x [10 x i32]]* %f) {
+entry:
+  br label %for.outermost.header
+
+for.outermost.header:                             ; preds = %entry, %for.outermost.latch
+  %indvars.outermost = phi i64 [ 0, %entry ], [ %indvars.outermost.next, %for.outermost.latch ]
+  %xor.outermost = phi i32 [ 0, %entry ], [ %xor.outermost.lcssa, %for.outermost.latch ]
+  br label %for.middle.header
+
+for.cond.cleanup:                                 ; preds = %for.outermost.latch
+  %xor.outermost.lcssa.lcssa = phi i32 [ %xor.outermost.lcssa, %for.outermost.latch ]
+  ret i32 %xor.outermost.lcssa.lcssa
+
+for.middle.header:                                ; preds = %for.outermost.header, %for.middle.latch
+  %indvars.middle = phi i64 [ 0, %for.outermost.header ], [ %indvars.middle.next, %for.middle.latch ]
+  %xor.middle = phi i32 [ %xor.outermost, %for.outermost.header ], [ %xor.middle.lcssa, %for.middle.latch ]
+  br label %for.innermost
+
+for.outermost.latch:                              ; preds = %for.middle.latch
+  %xor.outermost.lcssa = phi i32 [ %xor.middle.lcssa, %for.middle.latch ]
+  %indvars.outermost.next = add nuw nsw i64 %indvars.outermost, 1
+  %exitcond.outermost = icmp ne i64 %indvars.outermost.next, 10
+  br i1 %exitcond.outermost, label %for.outermost.header, label %for.cond.cleanup
+
+for.middle.latch:                                 ; preds = %for.innermost
+  %xor.middle.lcssa = phi i32 [ %xor.reduction, %for.innermost ]
+  %indvars.middle.next = add nuw nsw i64 %indvars.middle, 1
+  %exitcond.middle = icmp ne i64 %indvars.middle.next, 10
+  br i1 %exitcond.middle, label %for.middle.header, label %for.outermost.latch
+
+for.innermost:                                    ; preds = %for.middle.header, %for.innermost
+  %xor.innermost = phi i32 [ %xor.middle, %for.middle.header ], [ %xor.reduction, %for.innermost ]
+  %indvars.innermost = phi i64 [ 0, %for.middle.header ], [ %indvars.innermost.next, %for.innermost ]
+  %arrayidx12 = getelementptr inbounds [10 x [10 x i32]], [10 x [10 x i32]]* %e, i64 %indvars.innermost, i64 %indvars.middle, i64 %indvars.outermost
+  %0 = load i32, i32* %arrayidx12
+  %xor.reduction = add i32 %xor.innermost, %0
+  %indvars.innermost.next = add nuw nsw i64 %indvars.innermost, 1
+  %exitcond.innermost = icmp ne i64 %indvars.innermost.next, 10
+  br i1 %exitcond.innermost, label %for.innermost, label %for.middle.latch
 }
\ No newline at end of file