Index: llvm/lib/Transforms/Scalar/LoopFlatten.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -43,7 +43,10 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
+#include "llvm/Transforms/Utils/SimplifyIndVar.h"
 
 #define DEBUG_TYPE "loop-flatten"
 
@@ -74,6 +77,8 @@
   BranchInst *OuterBranch = nullptr;
   SmallPtrSet<Value *, 4> LinearIVUses;
   SmallPtrSet<PHINode *, 4> InnerPHIsToTransform;
+  std::map<PHINode *, PHINode *> Wide2OrigPHIs;
+  SmallPtrSet<PHINode *, 2> OrigPHIs;
 
   FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {};
 };
@@ -203,6 +208,9 @@
     // them specially when doing the transformation.
     if (&InnerPHI == FI.InnerInductionPHI)
       continue;
+    // Ignore the original phis, i.e. the phis that were widened.
+    if (FI.OrigPHIs.count(&InnerPHI))
+      continue;
 
     // Each inner loop PHI node must have two incoming values/blocks - one
     // from the pre-header, and one from the latch.
@@ -248,12 +256,16 @@
   }
 
   for (PHINode &OuterPHI : FI.OuterLoop->getHeader()->phis()) {
+    // Again, ignore the original phis.
+    if (FI.OrigPHIs.count(&OuterPHI))
+      continue;
     if (!SafeOuterPHIs.count(&OuterPHI)) {
       LLVM_DEBUG(dbgs() << "found unsafe PHI in outer loop: ";
                  OuterPHI.dump());
       return false;
     }
   }
 
+  LLVM_DEBUG(dbgs() << "checkPHIs: OK\n");
   return true;
 }
@@ -306,9 +318,12 @@
                     << RepeatedInstrCost << "\n");
   // Bail out if flattening the loops would cause instructions in the outer
   // loop but not in the inner loop to be executed extra times.
-  if (RepeatedInstrCost > RepeatedInstructionThreshold)
+  if (RepeatedInstrCost > RepeatedInstructionThreshold) {
+    LLVM_DEBUG(dbgs() << "checkOuterLoopInsts: not profitable, bailing.\n");
     return false;
+  }
 
+  LLVM_DEBUG(dbgs() << "checkOuterLoopInsts: OK\n");
   return true;
 }
@@ -321,6 +336,10 @@
   // require a div/mod to reconstruct in the flattened loop, so the
   // transformation wouldn't be profitable.
 
+  Value *InnerLimit = FI.InnerLimit;
+  if (auto *I = dyn_cast<SExtInst>(InnerLimit))
+    InnerLimit = I->getOperand(0);
+
   // Check that all uses of the inner loop's induction variable match the
   // expected pattern, recording the uses of the outer IV.
   SmallPtrSet<Value *, 4> ValidOuterPHIUses;
@@ -328,15 +347,32 @@
     if (U == FI.InnerIncrement)
       continue;
 
+    // After widening the IVs, a trunc instruction might have been introduced,
+    // so look through truncs.
+    if (isa<TruncInst>(U)) {
+      if (!U->hasOneUse())
+        return false;
+      U = *U->user_begin();
+    }
+
     LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump());
 
-    Value *MatchedMul, *MatchedItCount;
-    if (match(U, m_c_Add(m_Specific(FI.InnerInductionPHI),
-                         m_Value(MatchedMul))) &&
-        match(MatchedMul,
-              m_c_Mul(m_Specific(FI.OuterInductionPHI),
-                      m_Value(MatchedItCount))) &&
-        MatchedItCount == FI.InnerLimit) {
+    Value *MatchedMul;
+    Value *MatchedItCount;
+    bool IsAdd = match(U, m_c_Add(m_Specific(FI.InnerInductionPHI),
+                                  m_Value(MatchedMul))) &&
+                 match(MatchedMul, m_c_Mul(m_Specific(FI.OuterInductionPHI),
+                                           m_Value(MatchedItCount)));
+
+    // Matches the same pattern as above, except it also looks for truncs on
+    // the phi, which can be the result of widening the induction variables.
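+    // I.e., it matches: add(trunc(InnerPHI), mul(trunc(OuterPHI), ItCount)).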
+    bool IsAddTrunc =
+        match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)),
+                         m_Value(MatchedMul))) &&
+        match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)),
+                                  m_Value(MatchedItCount)));
+
+    if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerLimit) {
       LLVM_DEBUG(dbgs() << "Use is optimisable\n");
       ValidOuterPHIUses.insert(MatchedMul);
       FI.LinearIVUses.insert(U);
@@ -352,23 +388,35 @@
     if (U == FI.OuterIncrement)
       continue;
 
-    LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());
-
-    if (!ValidOuterPHIUses.count(U)) {
-      LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
-      return false;
-    } else {
+    auto IsValidOuterPHIUses = [&](User *U) -> bool {
+      LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());
+      if (!ValidOuterPHIUses.count(U)) {
+        LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
+        return false;
+      }
       LLVM_DEBUG(dbgs() << "Use is optimisable\n");
+      return true;
+    };
+
+    if (auto *V = dyn_cast<TruncInst>(U)) {
+      for (auto *K : V->users()) {
+        if (!IsValidOuterPHIUses(K))
+          return false;
+      }
+      continue;
     }
+
+    if (!IsValidOuterPHIUses(U))
+      return false;
   }
 
-  LLVM_DEBUG(dbgs() << "Found " << FI.LinearIVUses.size()
+  LLVM_DEBUG(dbgs() << "checkIVUsers: OK\n";
+             dbgs() << "Found " << FI.LinearIVUses.size()
                     << " value(s) that can be replaced:\n";
              for (Value *V : FI.LinearIVUses) {
                dbgs() << "  ";
                V->dump();
              });
-
   return true;
 }
@@ -413,15 +461,9 @@
   return OverflowResult::MayOverflow;
 }
 
-static bool FlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
-                            LoopInfo *LI, ScalarEvolution *SE,
-                            AssumptionCache *AC, TargetTransformInfo *TTI) {
-  Function *F = FI.OuterLoop->getHeader()->getParent();
-  LLVM_DEBUG(dbgs() << "Loop flattening running on outer loop "
-                    << FI.OuterLoop->getHeader()->getName() << " and inner loop "
-                    << FI.InnerLoop->getHeader()->getName() << " in "
-                    << F->getName() << "\n");
-
+static bool CanFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
+                               LoopInfo *LI, ScalarEvolution *SE,
+                               AssumptionCache *AC,
+                               const TargetTransformInfo *TTI) {
   SmallPtrSet<Instruction *, 8> IterationInstructions;
   if (!findLoopComponents(FI.InnerLoop, IterationInstructions,
                           FI.InnerInductionPHI, FI.InnerLimit,
                           FI.InnerIncrement, FI.InnerBranch, SE))
@@ -459,32 +501,16 @@
   if (!checkIVUsers(FI))
     return false;
 
-  // Check if the new iteration variable might overflow. In this case, we
-  // need to version the loop, and select the original version at runtime if
-  // the iteration space is too large.
-  // TODO: We currently don't version the loop.
-  // TODO: it might be worth using a wider iteration variable rather than
-  // versioning the loop, if a wide enough type is legal.
-  bool MustVersionLoop = true;
-  OverflowResult OR = checkOverflow(FI, DT, AC);
-  if (OR == OverflowResult::AlwaysOverflowsHigh ||
-      OR == OverflowResult::AlwaysOverflowsLow) {
-    LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n");
-    return false;
-  } else if (OR == OverflowResult::MayOverflow) {
-    LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n");
-  } else {
-    LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n");
-    MustVersionLoop = false;
-  }
-
-  // We cannot safely flatten the loop. Exit now.
-  if (MustVersionLoop)
-    return false;
+  LLVM_DEBUG(dbgs() << "CanFlattenLoopPair: OK\n");
+  return true;
+}
 
-  // Do the actual transformation.
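+// Do the actual transformation. All the checks performed by
+// CanFlattenLoopPair are assumed to have passed at this point.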
+static bool DoFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
+                              LoopInfo *LI, ScalarEvolution *SE,
+                              AssumptionCache *AC,
+                              const TargetTransformInfo *TTI) {
+  Function *F = FI.OuterLoop->getHeader()->getParent();
   LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n");
-
   {
     using namespace ore;
     OptimizationRemark Remark(DEBUG_TYPE, "Flattened",
                               FI.InnerLoop->getStartLoc(),
@@ -503,6 +529,12 @@
   // Fix up PHI nodes that take values from the inner loop back-edge, which
   // we are about to remove.
   FI.InnerInductionPHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());
+  // The old phi will be optimised away later, but for now we can't leave it
+  // in an invalid state, so update it too.
+  if (FI.Wide2OrigPHIs.find(FI.InnerInductionPHI) != FI.Wide2OrigPHIs.end()) {
+    auto *OrigPHI = FI.Wide2OrigPHIs.find(FI.InnerInductionPHI)->second;
+    OrigPHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());
+  }
   for (PHINode *PHI : FI.InnerPHIsToTransform)
     PHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());
@@ -517,10 +549,21 @@
   BranchInst::Create(InnerExitBlock, InnerExitingBlock);
   DT->deleteEdge(InnerExitingBlock, FI.InnerLoop->getHeader());
 
+  auto HasSExtUser = [](Value *V) -> Value * {
+    for (User *U : V->users())
+      if (isa<SExtInst>(U))
+        return U;
+    return nullptr;
+  };
+
   // Replace all uses of the polynomial calculated from the two induction
   // variables with the new one.
-  for (Value *V : FI.LinearIVUses)
+  for (Value *V : FI.LinearIVUses) {
+    // If the induction variable has been widened, look through the SExt.
+    if (Value *U = HasSExtUser(V))
+      V = U;
     V->replaceAllUsesWith(FI.OuterInductionPHI);
+  }
 
   // Tell LoopInfo, SCEV and the pass manager that the inner loop has been
   // deleted, and any information they have about the outer loop invalidated.
@@ -530,6 +573,85 @@
   return true;
 }
 
+static bool CanWidenIV(struct FlattenInfo &FI, DominatorTree *DT,
+                       LoopInfo *LI, ScalarEvolution *SE,
+                       AssumptionCache *AC, const TargetTransformInfo *TTI) {
+  LLVM_DEBUG(dbgs() << "Try widening the IVs\n");
+  Module *M = FI.InnerLoop->getHeader()->getParent()->getParent();
+  auto &DL = M->getDataLayout();
+  auto *InnerType = FI.InnerInductionPHI->getType();
+  auto *OuterType = FI.OuterInductionPHI->getType();
+  unsigned MaxLegalSize = DL.getLargestLegalIntTypeSizeInBits();
+  auto *MaxLegalType = DL.getLargestLegalIntType(M->getContext());
+
+  // If both induction types are less than the maximum legal integer width,
+  // promote both to the widest type available so we know calculating
+  // (OuterLimit * InnerLimit) as the new trip count is safe.
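+  // For example, with i32 IVs and i64 as the largest legal integer type,
+  // both IVs are widened to i64; the product of two 32-bit trip counts
+  // always fits in 64 bits, so the flattened trip count cannot overflow.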
+  if (InnerType != OuterType ||
+      InnerType->getScalarSizeInBits() >= MaxLegalSize ||
+      MaxLegalType->getScalarSizeInBits() <
+          InnerType->getScalarSizeInBits() * 2) {
+    LLVM_DEBUG(dbgs() << "Can't widen the IV\n");
+    return false;
+  }
+
+  SCEVExpander Rewriter(*SE, DL, "loopflatten");
+  SmallVector<WideIVInfo, 2> WideIVs;
+  SmallVector<WeakTrackingVH, 4> DeadInsts;
+  WideIVs.push_back({FI.InnerInductionPHI, MaxLegalType, false});
+  WideIVs.push_back({FI.OuterInductionPHI, MaxLegalType, false});
+  FI.OrigPHIs.insert(FI.InnerInductionPHI);
+  FI.OrigPHIs.insert(FI.OuterInductionPHI);
+  unsigned ElimExt;
+  unsigned Widened;
+
+  for (unsigned i = 0; i < WideIVs.size(); i++) {
+    PHINode *WidePhi =
+        createWideIV(WideIVs[i], LI, SE, Rewriter, DT, DeadInsts, ElimExt,
+                     Widened, true /* HasGuards */,
+                     true /* UsePostIncrementRanges */);
+    if (!WidePhi)
+      return false;
+    LLVM_DEBUG(dbgs() << "Created wide phi: "; WidePhi->dump());
+    FI.Wide2OrigPHIs.insert({WidePhi, WideIVs[i].NarrowIV});
+  }
+  // After widening, rediscover all the loop components.
+  return CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI);
+}
+
+static bool FlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
+                            LoopInfo *LI, ScalarEvolution *SE,
+                            AssumptionCache *AC,
+                            const TargetTransformInfo *TTI) {
+  Function *F = FI.OuterLoop->getHeader()->getParent();
+  LLVM_DEBUG(dbgs() << "Loop flattening running on outer loop "
+                    << FI.OuterLoop->getHeader()->getName()
+                    << " and inner loop "
+                    << FI.InnerLoop->getHeader()->getName() << " in "
+                    << F->getName() << "\n");
+
+  if (!CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI))
+    return false;
+
+  // Check if we can widen the induction variables to avoid overflow checks.
+  if (CanWidenIV(FI, DT, LI, SE, AC, TTI))
+    return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI);
+
+  // Check if the new iteration variable might overflow. In this case, we
+  // need to version the loop, and select the original version at runtime if
+  // the iteration space is too large.
+  // TODO: We currently don't version the loop.
+  OverflowResult OR = checkOverflow(FI, DT, AC);
+  if (OR == OverflowResult::AlwaysOverflowsHigh ||
+      OR == OverflowResult::AlwaysOverflowsLow) {
+    LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n");
+    return false;
+  } else if (OR == OverflowResult::MayOverflow) {
+    LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n");
+    return false;
+  }
+
+  LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n");
+  return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI);
+}
+
 bool Flatten(DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE,
              AssumptionCache *AC, TargetTransformInfo *TTI) {
   bool Changed = false;
@@ -539,6 +661,10 @@
       continue;
     struct FlattenInfo FI(OuterLoop, InnerLoop);
     Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI);
+    for (auto *Phi : FI.OrigPHIs) {
+      LLVM_DEBUG(dbgs() << "Clean up old phi: "; Phi->dump());
+      RecursivelyDeleteDeadPHINode(Phi);
+    }
   }
   return Changed;
 }
Index: llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
===================================================================
--- llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
+++ llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
@@ -1,6 +1,8 @@
 ; RUN: opt < %s -S -loop-flatten -debug-only=loop-flatten 2>&1 | FileCheck %s
 ; REQUIRES: asserts
 
+target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64"
+
 ; Every function in this file has a reason that it can't be transformed.
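+; The datalayout above makes i32 the largest legal integer type, so the new
+; IV-widening path cannot kick in and rescue any of these loops.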
 ; CHECK-NOT: Checks all passed, doing the transformation
Index: llvm/test/Transforms/LoopFlatten/widen-iv.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/LoopFlatten/widen-iv.ll
@@ -0,0 +1,78 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -S -loop-flatten -verify-loop-info -verify-dom-info -verify-scev -verify | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+
+; Function Attrs: nounwind
+define dso_local void @foo(i32* %A, i32 %N, i32 %M) local_unnamed_addr #0 {
+; CHECK-LABEL: @foo(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP17:%.*]] = icmp sgt i32 [[N:%.*]], 0
+; CHECK-NEXT:    br i1 [[CMP17]], label [[FOR_COND1_PREHEADER_LR_PH:%.*]], label [[FOR_COND_CLEANUP:%.*]]
+; CHECK:       for.cond1.preheader.lr.ph:
+; CHECK-NEXT:    [[CMP215:%.*]] = icmp sgt i32 [[M:%.*]], 0
+; CHECK-NEXT:    br i1 [[CMP215]], label [[FOR_COND1_PREHEADER_US_PREHEADER:%.*]], label [[FOR_COND_CLEANUP]]
+; CHECK:       for.cond1.preheader.us.preheader:
+; CHECK-NEXT:    [[TMP0:%.*]] = sext i32 [[M]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = sext i32 [[N]] to i64
+; CHECK-NEXT:    [[FLATTEN_TRIPCOUNT:%.*]] = mul i64 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    br label [[FOR_COND1_PREHEADER_US:%.*]]
+; CHECK:       for.cond1.preheader.us:
+; CHECK-NEXT:    [[INDVAR1:%.*]] = phi i64 [ [[INDVAR_NEXT2:%.*]], [[FOR_COND1_FOR_COND_CLEANUP3_CRIT_EDGE_US:%.*]] ], [ 0, [[FOR_COND1_PREHEADER_US_PREHEADER]] ]
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i64 [[INDVAR1]] to i32
+; CHECK-NEXT:    [[MUL_US:%.*]] = mul nsw i32 [[TMP2]], [[M]]
+; CHECK-NEXT:    br label [[FOR_BODY4_US:%.*]]
+; CHECK:       for.body4.us:
+; CHECK-NEXT:    [[INDVAR:%.*]] = phi i64 [ 0, [[FOR_COND1_PREHEADER_US]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i64 [[INDVAR]] to i32
+; CHECK-NEXT:    [[ADD_US:%.*]] = add nsw i32 [[TMP3]], [[MUL_US]]
+; CHECK-NEXT:    [[IDXPROM_US:%.*]] = sext i32 [[ADD_US]] to i64
+; CHECK-NEXT:    [[ARRAYIDX_US:%.*]] = getelementptr inbounds i32, i32* [[A:%.*]], i64 [[INDVAR1]]
+; CHECK-NEXT:    tail call void @f(i32* [[ARRAYIDX_US]])
+; CHECK-NEXT:    [[INDVAR_NEXT:%.*]] = add i64 [[INDVAR]], 1
+; CHECK-NEXT:    [[CMP2_US:%.*]] = icmp slt i64 [[INDVAR_NEXT]], [[TMP0]]
+; CHECK-NEXT:    br label [[FOR_COND1_FOR_COND_CLEANUP3_CRIT_EDGE_US]]
+; CHECK:       for.cond1.for.cond.cleanup3_crit_edge.us:
+; CHECK-NEXT:    [[INDVAR_NEXT2]] = add i64 [[INDVAR1]], 1
+; CHECK-NEXT:    [[CMP_US:%.*]] = icmp slt i64 [[INDVAR_NEXT2]], [[FLATTEN_TRIPCOUNT]]
+; CHECK-NEXT:    br i1 [[CMP_US]], label [[FOR_COND1_PREHEADER_US]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]]
+; CHECK:       for.cond.cleanup.loopexit:
+; CHECK-NEXT:    br label [[FOR_COND_CLEANUP]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %cmp17 = icmp sgt i32 %N, 0
+  br i1 %cmp17, label %for.cond1.preheader.lr.ph, label %for.cond.cleanup
+
+for.cond1.preheader.lr.ph:
+  %cmp215 = icmp sgt i32 %M, 0
+  br i1 %cmp215, label %for.cond1.preheader.us.preheader, label %for.cond.cleanup
+
+for.cond1.preheader.us.preheader:
+  br label %for.cond1.preheader.us
+
+for.cond1.preheader.us:
+  %i.018.us = phi i32 [ %inc6.us, %for.cond1.for.cond.cleanup3_crit_edge.us ], [ 0, %for.cond1.preheader.us.preheader ]
+  %mul.us = mul nsw i32 %i.018.us, %M
+  br label %for.body4.us
+
+for.body4.us:
+  %j.016.us = phi i32 [ 0, %for.cond1.preheader.us ], [ %inc.us, %for.body4.us ]
+  %add.us = add nsw i32 %j.016.us, %mul.us
+  %idxprom.us = sext i32 %add.us to i64
+  %arrayidx.us = getelementptr inbounds i32, i32* %A, i64 %idxprom.us
+  tail call void @f(i32* %arrayidx.us) #2
+  %inc.us = add nuw nsw i32 %j.016.us, 1
+  %cmp2.us = icmp slt i32 %inc.us, %M
+  br i1 %cmp2.us, label %for.body4.us, label %for.cond1.for.cond.cleanup3_crit_edge.us
+
+for.cond1.for.cond.cleanup3_crit_edge.us:
+  %inc6.us = add nuw nsw i32 %i.018.us, 1
+  %cmp.us = icmp slt i32 %inc6.us, %N
+  br i1 %cmp.us, label %for.cond1.preheader.us, label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+}
+
+declare dso_local void @f(i32* %0) local_unnamed_addr #1