Index: llvm/include/llvm/Transforms/Utils/LoopUtils.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -562,6 +562,7 @@
   PHINode *createWideIV(SCEVExpander &Rewriter);
 
+  PHINode *getOrigPhi() { return OrigPhi; };
   unsigned getNumElimExt() { return NumElimExt; };
   unsigned getNumWidened() { return NumWidened; };
 
Index: llvm/lib/Transforms/Scalar/LoopFlatten.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -45,6 +45,7 @@
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
 
 #define DEBUG_TYPE "loop-flatten"
 
@@ -75,6 +76,8 @@
   BranchInst *OuterBranch = nullptr;
   SmallPtrSet<Value *, 4> LinearIVUses;
   SmallPtrSet<PHINode *, 4> InnerPHIsToTransform;
+  std::map<PHINode *, PHINode *> Wide2OrigPHIs;
+  SmallPtrSet<PHINode *, 4> OrigPHIs;
 
   FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {};
 };
@@ -204,6 +207,9 @@
     // them specially when doing the transformation.
     if (&InnerPHI == FI.InnerInductionPHI)
       continue;
+    // Ignore the original phis, i.e. the phis that were widened.
+    if (FI.OrigPHIs.find(&InnerPHI) != FI.OrigPHIs.end())
+      continue;
 
     // Each inner loop PHI node must have two incoming values/blocks - one
     // from the pre-header, and one from the latch.
@@ -249,6 +255,9 @@
   }
 
   for (PHINode &OuterPHI : FI.OuterLoop->getHeader()->phis()) {
+    // Again, ignore the original phis.
+    if (FI.OrigPHIs.find(&OuterPHI) != FI.OrigPHIs.end())
+      continue;
     if (!SafeOuterPHIs.count(&OuterPHI)) {
       LLVM_DEBUG(dbgs() << "found unsafe PHI in outer loop: "; OuterPHI.dump());
       return false;
@@ -322,6 +331,11 @@
   // require a div/mod to reconstruct in the flattened loop, so the
   // transformation wouldn't be profitable.
 
+  Value *InnerLimit = FI.InnerLimit;
+  if (auto *I = dyn_cast<SExtInst>(InnerLimit))
+    InnerLimit = I->getOperand(0);
+
   // Check that all uses of the inner loop's induction variable match the
   // expected pattern, recording the uses of the outer IV.
   SmallPtrSet<Value *, 4> ValidOuterPHIUses;
@@ -329,15 +343,29 @@
     if (U == FI.InnerIncrement)
       continue;
 
+    // After widening the IVs, a trunc instruction might have been introduced, so
+    // look through truncs.
+    if (dyn_cast<TruncInst>(U))
+      U = *U->user_begin();
+
     LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump());
 
     Value *MatchedMul, *MatchedItCount;
-    if (match(U, m_c_Add(m_Specific(FI.InnerInductionPHI),
-                         m_Value(MatchedMul))) &&
-        match(MatchedMul,
-              m_c_Mul(m_Specific(FI.OuterInductionPHI),
-                      m_Value(MatchedItCount))) &&
-        MatchedItCount == FI.InnerLimit) {
+
+    bool IsAdd = match(U, m_c_Add(m_Specific(FI.InnerInductionPHI),
+                                  m_Value(MatchedMul))) &&
+                 match(MatchedMul, m_c_Mul(m_Specific(FI.OuterInductionPHI),
+                                           m_Value(MatchedItCount)));
+
+    // Matches the same pattern as above, except it also looks for truncs
+    // on the phi, which can be the result of widening the induction variables.
+    bool IsAddTrunc = match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)),
+                                       m_Value(MatchedMul))) &&
+                      match(MatchedMul,
+                            m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)),
+                                    m_Value(MatchedItCount)));
+
+    if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerLimit) {
       LLVM_DEBUG(dbgs() << "Use is optimisable\n");
       ValidOuterPHIUses.insert(MatchedMul);
       FI.LinearIVUses.insert(U);
@@ -353,14 +381,26 @@
     if (U == FI.OuterIncrement)
       continue;
 
-    LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());
-
-    if (!ValidOuterPHIUses.count(U)) {
-      LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
-      return false;
-    } else {
+    auto IsValidOuterPHIUses = [&] (User *U) -> bool {
+      LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());
+      if (!ValidOuterPHIUses.count(U)) {
+        LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
+        return false;
+      }
       LLVM_DEBUG(dbgs() << "Use is optimisable\n");
+      return true;
+    };
+
+    if (User *V = dyn_cast<TruncInst>(U)) {
+      for (auto *K : V->users()) {
+        if (!IsValidOuterPHIUses(K))
+          return false;
+      }
+      continue;
     }
+
+    if (!IsValidOuterPHIUses(U))
+      return false;
   }
 
   LLVM_DEBUG(dbgs() << "Found " << FI.LinearIVUses.size()
@@ -414,10 +454,10 @@
   return OverflowResult::MayOverflow;
 }
 
-static bool FlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
-                            LoopInfo *LI, ScalarEvolution *SE,
-                            AssumptionCache *AC, const TargetTransformInfo *TTI,
-                            std::function<void(Loop *)> markLoopAsDeleted) {
+static bool CanFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
+                               LoopInfo *LI, ScalarEvolution *SE,
+                               AssumptionCache *AC, const TargetTransformInfo *TTI,
+                               std::function<void(Loop *)> markLoopAsDeleted) {
   Function *F = FI.OuterLoop->getHeader()->getParent();
 
   LLVM_DEBUG(dbgs() << "Loop flattening running on outer loop "
             << FI.OuterLoop->getHeader()->getName() << " and inner loop "
             << FI.InnerLoop->getHeader()->getName() << " in "
             << F->getName() << "\n");
@@ -463,28 +503,14 @@
   if (!checkIVUsers(FI))
     return false;
 
-  // Check if the new iteration variable might overflow. In this case, we
-  // need to version the loop, and select the original version at runtime if
-  // the iteration space is too large.
-  // TODO: We currently don't version the loop.
-  // TODO: it might be worth using a wider iteration variable rather than
-  // versioning the loop, if a wide enough type is legal.
-  bool MustVersionLoop = true;
-  OverflowResult OR = checkOverflow(FI, DT, AC);
-  if (OR == OverflowResult::AlwaysOverflowsHigh ||
-      OR == OverflowResult::AlwaysOverflowsLow) {
-    LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n");
-    return false;
-  } else if (OR == OverflowResult::MayOverflow) {
-    LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n");
-  } else {
-    LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n");
-    MustVersionLoop = false;
-  }
+  return true;
+}
 
-  // We cannot safely flatten the loop. Exit now.
-  if (MustVersionLoop)
-    return false;
+static bool DoFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
+                              LoopInfo *LI, ScalarEvolution *SE,
+                              AssumptionCache *AC, const TargetTransformInfo *TTI,
+                              std::function<void(Loop *)> markLoopAsDeleted) {
+  Function *F = FI.OuterLoop->getHeader()->getParent();
 
   // Do the actual transformation.
   LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n");
@@ -507,6 +533,12 @@
   // Fix up PHI nodes that take values from the inner loop back-edge, which
   // we are about to remove.
   FI.InnerInductionPHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());
+  // The old Phi will be optimised away later, but for now we can't leave
+  // it in an invalid state, so we are updating it too.
+  if (FI.Wide2OrigPHIs.find(FI.InnerInductionPHI) != FI.Wide2OrigPHIs.end()) {
+    auto *OrigPHI = FI.Wide2OrigPHIs.find(FI.InnerInductionPHI)->second;
+    OrigPHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());
+  }
   for (PHINode *PHI : FI.InnerPHIsToTransform)
     PHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());
 
@@ -521,10 +553,21 @@
   BranchInst::Create(InnerExitBlock, InnerExitingBlock);
   DT->deleteEdge(InnerExitingBlock, FI.InnerLoop->getHeader());
 
+  auto HasSExtUser = [] (Value *V) -> Value * {
+    for (User *U : V->users())
+      if (dyn_cast<SExtInst>(U))
+        return U;
+    return nullptr;
+  };
+
   // Replace all uses of the polynomial calculated from the two induction
   // variables with the one new one.
-  for (Value *V : FI.LinearIVUses)
+  for (Value *V : FI.LinearIVUses) {
+    // If the induction variable has been widened, look through the SExt.
+    if (Value *U = HasSExtUser(V))
+      V = U;
     V->replaceAllUsesWith(FI.OuterInductionPHI);
+  }
 
   // Tell LoopInfo, SCEV and the pass manager that the inner loop has been
   // deleted, and any information that have about the outer loop invalidated.
@@ -532,10 +575,104 @@
   SE->forgetLoop(FI.OuterLoop);
   SE->forgetLoop(FI.InnerLoop);
   LI->erase(FI.InnerLoop);
+  return true;
+}
+
+static bool CanWidenIV(struct FlattenInfo &FI, DominatorTree *DT,
+                       LoopInfo *LI, ScalarEvolution *SE,
+                       AssumptionCache *AC, const TargetTransformInfo *TTI,
+                       std::function<void(Loop *)> markLoopAsDeleted) {
+  Module *M = FI.InnerLoop->getHeader()->getParent()->getParent();
+  auto &DL = M->getDataLayout();
+  auto *InnerType = FI.InnerInductionPHI->getType();
+  auto *OuterType = FI.OuterInductionPHI->getType();
+  unsigned MaxLegalSize = DL.getLargestLegalIntTypeSizeInBits();
+  auto *MaxLegalType = DL.getLargestLegalIntType(M->getContext());
+
+  LLVM_DEBUG(dbgs() << "Try widening the IVs\n");
+  // If both induction types are less than the maximum integer width, promote
+  // both to the widest type available so we know calculating Limit * Limit
+  // as the new trip count is safe.
+  if (InnerType != OuterType ||
+      InnerType->getScalarSizeInBits() == MaxLegalSize) {
+    LLVM_DEBUG(dbgs() << "Can't widen the IV\n");
+    return false;
+  }
+
+  SmallVector<WideIVInfo, 2> WideIVs;
+  auto AddCandidatePhi = [&] (BasicBlock::iterator I) {
+    for ( ; isa<PHINode>(I); ++I) {
+      LLVM_DEBUG(dbgs() << "Widen phi: "; cast<PHINode>(I)->dump());
+      WideIVs.push_back({cast<PHINode>(I), MaxLegalType, false});
+    }
+  };
+
+  AddCandidatePhi(FI.InnerLoop->getHeader()->begin());
+  AddCandidatePhi(FI.OuterLoop->getHeader()->begin());
+
+  if (WideIVs.empty())
+    return false;
+
+  SCEVExpander Rewriter(*SE, DL, "loopflatten");
+  SmallVector<WeakTrackingVH, 4> DeadInsts;
+
+  for (; !WideIVs.empty(); WideIVs.pop_back()) {
+    WidenIV Widener(WideIVs.back(), LI, SE, DT, DeadInsts, true, true);
+    if (PHINode *WidePhi = Widener.createWideIV(Rewriter)) {
+      LLVM_DEBUG(dbgs() << "Created wide phi: "; WidePhi->dump());
+      auto *OrigPhi = Widener.getOrigPhi();
+      FI.OrigPHIs.insert(OrigPhi);
+      FI.Wide2OrigPHIs.insert({WidePhi, OrigPhi});
+    } else {
+      return false;
+    }
+  }
   return true;
 }
 
+static bool FlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
+                            LoopInfo *LI, ScalarEvolution *SE,
+                            AssumptionCache *AC, const TargetTransformInfo *TTI,
+                            std::function<void(Loop *)> markLoopAsDeleted) {
+  Function *F = FI.OuterLoop->getHeader()->getParent();
+
+  LLVM_DEBUG(dbgs() << "Loop flattening running on outer loop "
+                    << FI.OuterLoop->getHeader()->getName() << " and inner loop "
+                    << FI.InnerLoop->getHeader()->getName() << " in "
+                    << F->getName() << "\n");
+
+  SmallPtrSet<Instruction *, 8> IterationInstructions;
+
+  if (!CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI, markLoopAsDeleted))
+    return false;
+
+  // Check if we can widen the induction variables to avoid overflow checks.
+  if (CanWidenIV(FI, DT, LI, SE, AC, TTI, markLoopAsDeleted)) {
+    // After widening, rediscover all the loop components.
+ if (!CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI, markLoopAsDeleted)) + return false; + return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, markLoopAsDeleted); + } + + // Check if the new iteration variable might overflow. In this case, we + // need to version the loop, and select the original version at runtime if + // the iteration space is too large. + // TODO: We currently don't version the loop. + OverflowResult OR = checkOverflow(FI, DT, AC); + if (OR == OverflowResult::AlwaysOverflowsHigh || + OR == OverflowResult::AlwaysOverflowsLow) { + LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n"); + return false; + } else if (OR == OverflowResult::MayOverflow) { + LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n"); + return false; + } + + LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n"); + return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, markLoopAsDeleted); +} + PreservedAnalyses LoopFlattenPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &Updater) { Index: llvm/test/Transforms/LoopFlatten/widen-iv.ll =================================================================== --- /dev/null +++ llvm/test/Transforms/LoopFlatten/widen-iv.ll @@ -0,0 +1,106 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -S -loop-flatten -verify-loop-info -verify-dom-info -verify-scev -verify | FileCheck %s +target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" + +; Function Attrs: nounwind +define dso_local void @foo(i32* %A, i32 %N, i32 %M) local_unnamed_addr #0 { +; CHECK-LABEL: @foo( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP17:%.*]] = icmp sgt i32 [[N:%.*]], 0 +; CHECK-NEXT: br i1 [[CMP17]], label [[FOR_COND1_PREHEADER_LR_PH:%.*]], label [[FOR_COND_CLEANUP:%.*]] +; CHECK: for.cond1.preheader.lr.ph: +; CHECK-NEXT: [[CMP215:%.*]] = icmp sgt i32 [[M:%.*]], 0 +; CHECK-NEXT: br i1 [[CMP215]], label 
[[FOR_COND1_PREHEADER_US_PREHEADER:%.*]], label [[FOR_COND1_PREHEADER_PREHEADER:%.*]] +; CHECK: for.cond1.preheader.preheader: +; CHECK-NEXT: br label [[FOR_COND1_PREHEADER:%.*]] +; CHECK: for.cond1.preheader.us.preheader: +; CHECK-NEXT: [[TMP0:%.*]] = sext i32 [[N]] to i64 +; CHECK-NEXT: [[TMP1:%.*]] = sext i32 [[M]] to i64 +; CHECK-NEXT: [[FLATTEN_TRIPCOUNT:%.*]] = mul i64 [[TMP1]], [[TMP0]] +; CHECK-NEXT: br label [[FOR_COND1_PREHEADER_US:%.*]] +; CHECK: for.cond1.preheader.us: +; CHECK-NEXT: [[INDVAR:%.*]] = phi i64 [ [[INDVAR_NEXT:%.*]], [[FOR_COND1_FOR_COND_CLEANUP3_CRIT_EDGE_US:%.*]] ], [ 0, [[FOR_COND1_PREHEADER_US_PREHEADER]] ] +; CHECK-NEXT: [[I_018_US:%.*]] = phi i32 [ [[INC6_US:%.*]], [[FOR_COND1_FOR_COND_CLEANUP3_CRIT_EDGE_US]] ], [ 0, [[FOR_COND1_PREHEADER_US_PREHEADER]] ] +; CHECK-NEXT: [[TMP2:%.*]] = trunc i64 [[INDVAR]] to i32 +; CHECK-NEXT: [[MUL_US:%.*]] = mul nsw i32 [[TMP2]], [[M]] +; CHECK-NEXT: br label [[FOR_BODY4_US:%.*]] +; CHECK: for.body4.us: +; CHECK-NEXT: [[INDVAR1:%.*]] = phi i64 [ 0, [[FOR_COND1_PREHEADER_US]] ] +; CHECK-NEXT: [[J_016_US:%.*]] = phi i32 [ 0, [[FOR_COND1_PREHEADER_US]] ] +; CHECK-NEXT: [[TMP3:%.*]] = trunc i64 [[INDVAR1]] to i32 +; CHECK-NEXT: [[ADD_US:%.*]] = add nsw i32 [[TMP3]], [[MUL_US]] +; CHECK-NEXT: [[IDXPROM_US:%.*]] = sext i32 [[ADD_US]] to i64 +; CHECK-NEXT: [[ARRAYIDX_US:%.*]] = getelementptr inbounds i32, i32* [[A:%.*]], i64 [[INDVAR]] +; CHECK-NEXT: tail call void @f(i32* [[ARRAYIDX_US]]) +; CHECK-NEXT: [[INDVAR_NEXT2:%.*]] = add i64 [[INDVAR1]], 1 +; CHECK-NEXT: [[INC_US:%.*]] = add nuw nsw i32 [[J_016_US]], 1 +; CHECK-NEXT: [[CMP2_US:%.*]] = icmp slt i64 [[INDVAR_NEXT2]], [[TMP1]] +; CHECK-NEXT: br label [[FOR_COND1_FOR_COND_CLEANUP3_CRIT_EDGE_US]] +; CHECK: for.cond1.for.cond.cleanup3_crit_edge.us: +; CHECK-NEXT: [[INDVAR_NEXT]] = add i64 [[INDVAR]], 1 +; CHECK-NEXT: [[INC6_US]] = add nuw nsw i32 [[I_018_US]], 1 +; CHECK-NEXT: [[CMP_US:%.*]] = icmp slt i64 [[INDVAR_NEXT]], [[FLATTEN_TRIPCOUNT]] +; 
CHECK-NEXT: br i1 [[CMP_US]], label [[FOR_COND1_PREHEADER_US]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]] +; CHECK: for.cond1.preheader: +; CHECK-NEXT: [[I_018:%.*]] = phi i32 [ [[INC6:%.*]], [[FOR_COND1_PREHEADER]] ], [ 0, [[FOR_COND1_PREHEADER_PREHEADER]] ] +; CHECK-NEXT: [[INC6]] = add nuw nsw i32 [[I_018]], 1 +; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[INC6]], [[N]] +; CHECK-NEXT: br i1 [[CMP]], label [[FOR_COND1_PREHEADER]], label [[FOR_COND_CLEANUP_LOOPEXIT20:%.*]] +; CHECK: for.cond.cleanup.loopexit: +; CHECK-NEXT: br label [[FOR_COND_CLEANUP]] +; CHECK: for.cond.cleanup.loopexit20: +; CHECK-NEXT: br label [[FOR_COND_CLEANUP]] +; CHECK: for.cond.cleanup: +; CHECK-NEXT: ret void +; +entry: + %cmp17 = icmp sgt i32 %N, 0 + br i1 %cmp17, label %for.cond1.preheader.lr.ph, label %for.cond.cleanup + +for.cond1.preheader.lr.ph: ; preds = %entry + %cmp215 = icmp sgt i32 %M, 0 + br i1 %cmp215, label %for.cond1.preheader.us.preheader, label %for.cond1.preheader.preheader + +for.cond1.preheader.preheader: ; preds = %for.cond1.preheader.lr.ph + br label %for.cond1.preheader + +for.cond1.preheader.us.preheader: ; preds = %for.cond1.preheader.lr.ph + br label %for.cond1.preheader.us + +for.cond1.preheader.us: ; preds = %for.cond1.preheader.us.preheader, %for.cond1.for.cond.cleanup3_crit_edge.us + %i.018.us = phi i32 [ %inc6.us, %for.cond1.for.cond.cleanup3_crit_edge.us ], [ 0, %for.cond1.preheader.us.preheader ] + %mul.us = mul nsw i32 %i.018.us, %M + br label %for.body4.us + +for.body4.us: ; preds = %for.cond1.preheader.us, %for.body4.us + %j.016.us = phi i32 [ 0, %for.cond1.preheader.us ], [ %inc.us, %for.body4.us ] + %add.us = add nsw i32 %j.016.us, %mul.us + %idxprom.us = sext i32 %add.us to i64 + %arrayidx.us = getelementptr inbounds i32, i32* %A, i64 %idxprom.us + tail call void @f(i32* %arrayidx.us) #2 + %inc.us = add nuw nsw i32 %j.016.us, 1 + %cmp2.us = icmp slt i32 %inc.us, %M + br i1 %cmp2.us, label %for.body4.us, label %for.cond1.for.cond.cleanup3_crit_edge.us + 
+for.cond1.for.cond.cleanup3_crit_edge.us: ; preds = %for.body4.us + %inc6.us = add nuw nsw i32 %i.018.us, 1 + %cmp.us = icmp slt i32 %inc6.us, %N + br i1 %cmp.us, label %for.cond1.preheader.us, label %for.cond.cleanup.loopexit + +for.cond1.preheader: ; preds = %for.cond1.preheader.preheader, %for.cond1.preheader + %i.018 = phi i32 [ %inc6, %for.cond1.preheader ], [ 0, %for.cond1.preheader.preheader ] + %inc6 = add nuw nsw i32 %i.018, 1 + %cmp = icmp slt i32 %inc6, %N + br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup.loopexit20 + +for.cond.cleanup.loopexit: ; preds = %for.cond1.for.cond.cleanup3_crit_edge.us + br label %for.cond.cleanup + +for.cond.cleanup.loopexit20: ; preds = %for.cond1.preheader + br label %for.cond.cleanup + +for.cond.cleanup: ; preds = %for.cond.cleanup.loopexit20, %for.cond.cleanup.loopexit, %entry + ret void +} + +declare dso_local void @f(i32* %0) local_unnamed_addr #1