Index: include/llvm/InitializePasses.h =================================================================== --- include/llvm/InitializePasses.h +++ include/llvm/InitializePasses.h @@ -142,6 +142,7 @@ void initializeIPSCCPPass(PassRegistry&); void initializeIVUsersPass(PassRegistry&); void initializeIfConverterPass(PassRegistry&); +void initializeInductiveRangeCheckEliminationPass(PassRegistry&); void initializeIndVarSimplifyPass(PassRegistry&); void initializeInlineCostAnalysisPass(PassRegistry&); void initializeInstCombinerPass(PassRegistry&); Index: include/llvm/LinkAllPasses.h =================================================================== --- include/llvm/LinkAllPasses.h +++ include/llvm/LinkAllPasses.h @@ -86,6 +86,7 @@ (void) llvm::createGlobalsModRefPass(); (void) llvm::createIPConstantPropagationPass(); (void) llvm::createIPSCCPPass(); + (void) llvm::createInductiveRangeCheckEliminationPass(); (void) llvm::createIndVarSimplifyPass(); (void) llvm::createInstructionCombiningPass(); (void) llvm::createInternalizePass(); Index: include/llvm/Transforms/Scalar.h =================================================================== --- include/llvm/Transforms/Scalar.h +++ include/llvm/Transforms/Scalar.h @@ -98,6 +98,13 @@ //===----------------------------------------------------------------------===// // +// InductiveRangeCheckElimination - Transform loops to elide range checks on +// linear functions of the induction variable. +// +Pass *createInductiveRangeCheckEliminationPass(); + +//===----------------------------------------------------------------------===// +// // InductionVariableSimplify - Transform induction variables in a program to all // use a single canonical induction variable per loop. 
// Index: lib/Transforms/Scalar/CMakeLists.txt =================================================================== --- lib/Transforms/Scalar/CMakeLists.txt +++ lib/Transforms/Scalar/CMakeLists.txt @@ -9,6 +9,7 @@ EarlyCSE.cpp FlattenCFGPass.cpp GVN.cpp + InductiveRangeCheckElimination.cpp IndVarSimplify.cpp JumpThreading.cpp LICM.cpp Index: lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp =================================================================== --- /dev/null +++ lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -0,0 +1,874 @@ +//===-- InductiveRangeCheckElimination.cpp - ------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// The InductiveRangeCheckElimination pass splits a loop's iteration space into +// three disjoint ranges. It does that in a way such that the loop running in +// the middle range provably does not need range checks. 
+//===----------------------------------------------------------------------===//
+
+
+#include "llvm/ADT/Optional.h"
+
+#include "llvm/Analysis/AssumptionTracker.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ValueTracking.h"
+
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/IR/Verifier.h"
+
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/SimplifyIndVar.h"
+#include "llvm/Transforms/Utils/UnrollLoop.h"
+
+#include "llvm/Pass.h"
+
+#include <functional>
+#include <utility>
+
+using namespace llvm;
+
+/// Hidden option: loops with at least this many basic blocks are not
+/// transformed (ConstrainLoopRange gives up on them as "too large").
+/// Declared static so the option object does not leak out of this file.
+static cl::opt<unsigned> LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden,
+                                        cl::init(64));
+
+#define DEBUG_TYPE "irce"
+
+namespace {
+
+/// An inductive range check is a branch that is conditional on an expression of
+/// the form
+///
+///   0 <= (Offset + Scale * I) < Length
+///
+/// where `I' is the canonical induction variable of a loop to which Offset and
+/// Scale are loop invariant, and Length is >= 0.
+
+// NOTE(review): in this copy of the patch several template argument lists
+// (e.g. on AssertingVH, std::pair, SpecificBumpPtrAllocator, Optional,
+// isa/dyn_cast, addRequired/addPreserved, m_ICmpWithPred and SmallVectorImpl)
+// appear to be missing -- confirm against the original revision before
+// building.
+class InductiveRangeCheck {
+  // Start value of the affine index recurrence (filled in by create()).
+  const SCEV *Offset;
+  // Per-iteration step of the index recurrence (filled in by create()).
+  const SCEV *Scale;
+  // The upper limit the index is compared against (filled in by create()).
+  Value *Length;
+  // The conditional branch performing the range check.
+  AssertingVH Branch;
+
+  // Private: objects are built by create() via placement new.
+  InductiveRangeCheck() { }
+
+ public:
+  const SCEV *getOffset() const { return Offset; }
+  const SCEV *getScale() const { return Scale; }
+  Value *getLength() const { return Length; }
+
+  /// Print Offset, Scale, Length and the branch to dbgs().
+  void dump() const {
+    dbgs() << "InductiveRangeCheck:\n";
+    dbgs() << " Offset: "; Offset->print(dbgs());
+    dbgs() << " Scale: "; Scale->print(dbgs());
+    dbgs() << " Length: "; Length->dump();
+    dbgs() << " Branch: "; getBranch()->dump();
+  }
+
+  BranchInst *getBranch() const { return Branch; }
+
+  /// Represents an integer range [Range.first, Range.second). If Range.second
+  /// <= Range.first, then the value denotes the empty range.
+  typedef std::pair Range;
+  typedef SpecificBumpPtrAllocator AllocatorTy;
+
+  /// Computes a range for the induction variable in which the range check is
+  /// redundant and can be constant-folded away.  The instructions computing
+  /// the range bounds are inserted before Loc.
+  Optional computeSafeIterationSpace(ScalarEvolution &SE,
+                                     Instruction *Loc) const;
+
+  /// Create an inductive range check out of BI if possible, else return
+  /// nullptr.
+  static InductiveRangeCheck *create(AllocatorTy &Alloc, BranchInst *BI,
+                                     Loop *L, ScalarEvolution &SE);
+};
+
+
+/// Loop pass that collects the inductive range checks of a loop and splits the
+/// loop's iteration space so that the checks become redundant in the main
+/// loop.
+class InductiveRangeCheckElimination : public LoopPass {
+  // NOTE(review): runOnLoop allocates from a local allocator instead; this
+  // member looks unused -- confirm.
+  InductiveRangeCheck::AllocatorTy Allocator;
+
+ public:
+  static char ID;
+  InductiveRangeCheckElimination() : LoopPass(ID) {
+    initializeInductiveRangeCheckEliminationPass(*PassRegistry::getPassRegistry());
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired();
+    AU.addPreserved();
+
+    // TODO: we can preserve this
+    AU.addRequiredID(LoopSimplifyID);
+
+    AU.addRequiredID(LCSSAID);
+    AU.addPreservedID(LCSSAID);
+
+    AU.addRequired();
+    AU.addPreserved();
+
+    AU.addRequired();
+  }
+
+  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
+};
+
+char InductiveRangeCheckElimination::ID = 0;
+
+}
+
+INITIALIZE_PASS(InductiveRangeCheckElimination, "irce",
+                "Inductive range check elimination", false, false)
+
+/// Split a condition into something semantically equivalent to (0 <= I <
+/// Limit), both comparisons signed and Limit loop invariant on L and
+/// non-negative.  On success, return true and set Index to I and UpperLimit to
+/// Limit.  Return false on failure (we may still write to UpperLimit and Index
+/// on failure).
+
+static bool SplitRangeCheckCondition(Loop *L, ScalarEvolution &SE,
+                                     Value *Condition, const SCEV *&Index,
+                                     Value *&UpperLimit) {
+
+  // TODO: currently this catches some silly cases like comparing "%idx slt 1".
+  // Our transformations are still correct, but less likely to be profitable in
+  // those cases. We have to come up with some heuristics that pick out the
+  // range checks that are more profitable to clone a loop for.
+
+  using namespace llvm::PatternMatch;
+
+  Value *A = nullptr;
+  Value *B = nullptr;
+  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
+
+  // In these early checks we assume that the matched UpperLimit is positive.
+  // We'll verify that fact later, before returning true.
+
+  if (match(Condition, m_And(m_Value(A), m_Value(B)))) {
+    // Form 1: "lower bound check AND upper bound check" on the same IndexV.
+    // The lower bound check compares IndexV against the constant 0 or -1
+    // (presumably "sge 0" / "sgt -1"; the predicate arguments are missing in
+    // this copy -- confirm).
+    Value *IndexV = nullptr;
+    auto MatchLowerBoundCheck = m_CombineOr(
+        m_ICmpWithPred(
+            m_Value(IndexV), m_ConstantInt<0>()),
+        m_ICmpWithPred(
+            m_Value(IndexV), m_ConstantInt<-1>()));
+
+    Value *ExpectedUpperBoundCheck = nullptr;
+
+    // Either operand of the `and' may be the lower bound check; the other one
+    // is then expected to be the upper bound check.
+    if (match(A, MatchLowerBoundCheck))
+      ExpectedUpperBoundCheck = B;
+    else if (match(B, MatchLowerBoundCheck))
+      ExpectedUpperBoundCheck = A;
+    else
+      return false;
+
+    auto MatchUpperBoundCheck = m_CombineOr(
+        m_ICmpWithPred(
+            m_Specific(IndexV), m_Value(UpperLimit)),
+        m_ICmpWithPred(
+            m_Specific(IndexV), m_Value(UpperLimit)));
+
+    if (!match(ExpectedUpperBoundCheck, MatchUpperBoundCheck))
+      return false;
+
+    Index = SE.getSCEV(IndexV);
+
+    if (isa(Index))
+      return false;
+
+  } else if (match(Condition, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
+    // Form 2: a single comparison.  The "greater than" predicates are handled
+    // by swapping the operands and falling into the "less than" case.
+    switch (Pred) {
+    default:
+      return false;
+
+    case ICmpInst::ICMP_SGT:
+      std::swap(A, B);
+      // fall through
+    case ICmpInst::ICMP_SLT:
+      // A signed compare says nothing about the lower bound, so the index
+      // itself must be known to be non-negative.
+      UpperLimit = B;
+      Index = SE.getSCEV(A);
+      if (isa(Index) || !SE.isKnownNonNegative(Index))
+        return false;
+      break;
+
+    case ICmpInst::ICMP_UGT:
+      std::swap(A, B);
+      // fall through
+    case ICmpInst::ICMP_ULT:
+      // No non-negativity check on the index here: UpperLimit is verified to
+      // be non-negative below, and an unsigned compare against such a limit
+      // cannot succeed for an index with the sign bit set.
+      UpperLimit = B;
+      Index = SE.getSCEV(A);
+      if (isa(Index))
+        return false;
+      break;
+    }
+  } else {
+    return false;
+  }
+
+  // Common tail: the limit must be analyzable, non-negative and invariant in L.
+  const SCEV *UpperLimitSCEV = SE.getSCEV(UpperLimit);
+  if (isa(UpperLimitSCEV) || !SE.isKnownNonNegative(UpperLimitSCEV))
+    return false;
+
+  if (SE.getLoopDisposition(UpperLimitSCEV, L) != ScalarEvolution::LoopInvariant) {
+    // NOTE(review): the message says "in function" but the value printed is
+    // the module identifier (Header -> Function -> Module) -- confirm intent.
+    DEBUG(
+      dbgs() << " in function: " << L->getHeader()->getParent()->getParent()->getModuleIdentifier() << " ";
+      dbgs() << " could not make length loop invariant: " << UpperLimit->getName() << "\n";);
+    return false;
+  }
+
+  return true;
+}
+
+
+/// Returns nullptr when BI is unconditional, when BI terminates the latch of
+/// L, when SplitRangeCheckCondition rejects BI's condition, or when the index
+/// is not an affine add recurrence on L.  Otherwise returns a new object
+/// allocated from A.
+InductiveRangeCheck *InductiveRangeCheck::create(
+    InductiveRangeCheck::AllocatorTy &A, BranchInst *BI, Loop *L,
+    ScalarEvolution &SE) {
+
+  if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch())
+    return nullptr;
+
+  Value *Length = nullptr;
+  const SCEV *IndexSCEV = nullptr;
+
+  if (!SplitRangeCheckCondition(L, SE, BI->getCondition(), IndexSCEV, Length))
+    return nullptr;
+
+  assert(IndexSCEV && Length && "contract with SplitRangeCheckCondition!");
+
+  const SCEVAddRecExpr *IndexAddRec = dyn_cast(IndexSCEV);
+
+  if (!(IndexAddRec && IndexAddRec->getLoop() == L && IndexAddRec->isAffine()))
+    return nullptr;
+
+  // The private default constructor leaves the fields unset; fill in all four.
+  InductiveRangeCheck *IRC = new (A.Allocate()) InductiveRangeCheck;
+  IRC->Length = Length;
+  IRC->Offset = IndexAddRec->getStart();
+  IRC->Scale = IndexAddRec->getStepRecurrence(SE);
+  IRC->Branch = BI;
+  return IRC;
+}
+
+
+// Emits "A slt B ? A : B" (signed minimum) before InsertPt and returns the
+// select.  NOTE(review): the ';' after the closing brace is redundant.
+static Value *ConstructSMinOf(Value *A, Value *B, Instruction *InsertPt) {
+  ICmpInst *Cmp = new ICmpInst(InsertPt, ICmpInst::ICMP_SLT, A, B);
+  return SelectInst::Create(Cmp, A, B, "", InsertPt);
+};
+
+
+// Emits "A sgt B ? A : B" (signed maximum) before InsertPt and returns the
+// select.  NOTE(review): the ';' after the closing brace is redundant.
+static Value *ConstructSMaxOf(Value *A, Value *B, Instruction *InsertPt) {
+  ICmpInst *Cmp = new ICmpInst(InsertPt, ICmpInst::ICMP_SGT, A, B);
+  return SelectInst::Create(Cmp, A, B, "", InsertPt);
+};
+
+
+/// Clone every block of loop L twice: once for the pre loop (suffix
+/// ".preloop") and once for the post loop (suffix ".postloop").  The clones are
+/// appended to PreLoopBlocks / PostLoopBlocks in the order of L->getBlocks(),
+/// and the block-to-clone mapping is recorded in PreLoopMap / PostLoopMap.
+static void ClonePreAndPostLoopBodies(
+    Function *Fn, Loop *L, ValueToValueMapTy &PreLoopMap,
+    SmallVectorImpl &PreLoopBlocks,
+    ValueToValueMapTy &PostLoopMap,
+    SmallVectorImpl &PostLoopBlocks) {
+
+  for (BasicBlock *BB : L->getBlocks()) {
+    BasicBlock *PreClone = CloneBasicBlock(BB, PreLoopMap, ".preloop", Fn);
+    PreLoopBlocks.push_back(PreClone);
+    PreLoopMap[BB] = PreClone;
+
+    BasicBlock *PostClone = CloneBasicBlock(BB, PostLoopMap, ".postloop", Fn);
+    PostLoopBlocks.push_back(PostClone);
+    PostLoopMap[BB] = PostClone;
+  }
+}
+
+/// Edit PHI nodes in all the exit blocks for the 2 * N new incoming edges from
+/// the pre and post loops.
+/// In addition, every instruction of the cloned blocks is remapped through
+/// PreLoopMap / PostLoopMap here, so that the clones refer to each other
+/// rather than to the original loop.  PreLoopBlocks[i] and PostLoopBlocks[i]
+/// must be the clones of L->getBlocks()[i].
+static void UpdateExitBlockPHIs(
+    Loop *L, ValueToValueMapTy &PreLoopMap,
+    const SmallVectorImpl<BasicBlock *> &PreLoopBlocks,
+    ValueToValueMapTy &PostLoopMap,
+    const SmallVectorImpl<BasicBlock *> &PostLoopBlocks) {
+
+  // The loop below indexes all three sequences in lock-step.
+  assert(PreLoopBlocks.size() == PostLoopBlocks.size() &&
+         PreLoopBlocks.size() == L->getBlocks().size() &&
+         "expected one pre and one post clone per loop block!");
+
+  for (size_t i = 0, e = PreLoopBlocks.size(); i != e; ++i) {
+    BasicBlock *PreBB = PreLoopBlocks[i];
+    BasicBlock *PostBB = PostLoopBlocks[i];
+
+    for (Instruction &I : *PreBB)
+      RemapInstruction(&I, PreLoopMap,
+                       RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);
+
+    for (Instruction &I : *PostBB)
+      RemapInstruction(&I, PostLoopMap,
+                       RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);
+
+    // Since we're splitting up the loop body, exit blocks will now
+    // have two more predecessors and their PHI nodes need to be
+    // edited.  No phi nodes need to be introduced because the loop is
+    // in LCSSA.
+
+    BasicBlock *OrigBB = L->getBlocks()[i];
+    if (!L->isLoopExiting(OrigBB))
+      continue;
+
+    for (auto ExitBBI = succ_begin(OrigBB), ExitBBE = succ_end(OrigBB);
+         ExitBBI != ExitBBE; ++ExitBBI) {
+
+      // Only successors outside the loop are exit blocks.
+      if (L->contains(*ExitBBI))
+        continue;
+
+      // PHI nodes are grouped at the top of the block; stop at the first
+      // non-PHI instruction.
+      for (Instruction &I : **ExitBBI) {
+        PHINode *PN = dyn_cast<PHINode>(&I);
+        if (!PN)
+          break;
+
+        Value *OldIncoming = PN->getIncomingValueForBlock(OrigBB);
+
+        assert((PostLoopMap.find(OldIncoming) == PostLoopMap.end()) ==
+               (PreLoopMap.find(OldIncoming) == PreLoopMap.end()) &&
+               "PreBB and PostBB are clones!");
+
+        if (PreLoopMap.count(OldIncoming)) {
+          // The value is defined inside the loop: each clone supplies its own
+          // copy of it.
+          PN->addIncoming(PreLoopMap[OldIncoming], PreBB);
+          PN->addIncoming(PostLoopMap[OldIncoming], PostBB);
+        } else {
+          // The value was not cloned; all three edges carry the same value.
+          PN->addIncoming(OldIncoming, PreBB);
+          PN->addIncoming(OldIncoming, PostBB);
+        }
+      }
+    }
+  }
+}
+
+
+/// Update PHI nodes in Header: it still has one predecessor (its selector), but
+/// that predecessor itself has two possible incoming values for each PHI node.
+/// Callback is invoked on each pair (existing phi, new selector phi).
+///
+/// For every PHI node in Header a new PHI named "selector.phi" is created in
+/// MainLoopSelector.  It merges the value that used to come in from Preheader
+/// with the backedge value as computed by the pre loop (looked up in
+/// PreLoopMap, arriving via PreLoopExit).  The Header PHI's Preheader edge is
+/// then redirected to MainLoopSelector and fed from the new PHI.
+static void UpdatePHINodesInHeader(
+    BasicBlock *Header, BasicBlock *MainLoopSelector,
+    BasicBlock *Preheader, BasicBlock *Latch,
+    BasicBlock *PreLoopExit, const ValueToValueMapTy &PreLoopMap,
+    const std::function<void(PHINode *, PHINode *)> &Callback) {
+
+  for (Instruction &HI : *Header) {
+    // PHI nodes are grouped at the top of the block; we are done at the first
+    // non-PHI instruction.
+    PHINode *PN = dyn_cast<PHINode>(&HI);
+    if (!PN) return;
+
+    // Check the shape we rely on before touching anything.
+    assert(PN->getNumIncomingValues() == 2 && "exactly one latch!");
+
+    int Idx = PN->getBasicBlockIndex(Preheader);
+    assert(Idx >= 0 && "no branch from preheader?");
+
+    PHINode *SelectorPHI =
+        PHINode::Create(PN->getType(), 2, "selector.phi",
+                        MainLoopSelector->begin());
+
+    // Value on first iteration -- whatever we already had
+    SelectorPHI->addIncoming(PN->getIncomingValue(Idx), Preheader);
+
+    // N'th iteration -- it is like we're taking the backedge, but
+    // we need to use the value computed by the pre loop instead of
+    // the value we've computed.
+    Value *SelfBackedgeValue = PN->getIncomingValueForBlock(Latch);
+    auto It = PreLoopMap.find(SelfBackedgeValue);
+
+    if (It != PreLoopMap.end())
+      SelectorPHI->addIncoming(It->second, PreLoopExit);
+    else
+      SelectorPHI->addIncoming(SelfBackedgeValue, PreLoopExit);
+
+    PN->setIncomingBlock(Idx, MainLoopSelector);
+    PN->setIncomingValue(Idx, SelectorPHI);
+
+    Callback(PN, SelectorPHI);
+  }
+}
+
+
+/// Update PHI nodes in the postloop. PHI nodes had an incoming edge from its
+/// own Latch and one from Preheader. The former should have been updated by
+/// RemapInstruction. Here we update the latter to come in from MainLoopExit,
+/// and add a new edge coming in from the MainLoopSelector. The post loop does
+/// not have a selector.
+
+static void UpdatePHINodesInPostLoop(BasicBlock *Header, BasicBlock *Preheader,
+                                     BasicBlock *Latch, BasicBlock *MainLoopExit,
+                                     BasicBlock *MainLoopSelector,
+                                     BasicBlock *PostLoopHeader,
+                                     ValueToValueMapTy &PostLoopMap) {
+  // Walk the PHI nodes of the main loop header and of the post loop header in
+  // lock-step; the post loop header is a clone, so they appear in the same
+  // order (checked by the "cloned out of order?" assert below).
+  auto MainLoopHeaderIt = Header->begin();
+  for (Instruction &PI : *PostLoopHeader) {
+    PHINode *PN = dyn_cast<PHINode>(&PI);
+    if (!PN) break;
+
+    assert(PN->getNumIncomingValues() == 2 && "exactly one latch!");
+
+    int Idx = PN->getBasicBlockIndex(Preheader);
+    assert(Idx >= 0 && "no incoming value for Preheader?");
+
+    // The edge that used to come from Preheader now comes from MainLoopExit...
+    PN->setIncomingBlock(Idx, MainLoopExit);
+
+    PHINode *MainLoopPN = cast<PHINode>(MainLoopHeaderIt++);
+    assert(PostLoopMap[MainLoopPN] == PN && "cloned out of order?");
+
+    // ...and carries the value the main loop computed on its backedge.
+    Value *NextValue = MainLoopPN->getIncomingValueForBlock(Latch);
+    PN->setIncomingValue(Idx, NextValue);
+
+    // New edge: the main loop selector may branch straight to the post loop,
+    // in which case the post loop starts from the selector's value.
+    PN->addIncoming(MainLoopPN->getIncomingValueForBlock(MainLoopSelector),
+                    MainLoopSelector);
+  }
+}
+
+
+/// ConstrainLoopRange splits loop L into three loops: the pre loop, the "main"
+/// loop and the post loop in a way that the main loop executes iterations where
+/// the value of the induction variable is a subset of Range. The pre and post
+/// loops execute the remaining iterations. Returns false on failure (no
+/// changes are made to the CFG in that case) and true on success.
+
+// NOTE(review): the AT and P parameters are not referenced in the body below.
+// NOTE(review): several template argument lists (on isa/cast/dyn_cast,
+// SmallVector, Optional, AssertingVH, getAnalysis, m_ICmpWithPred) appear to
+// be missing in this copy of the patch -- confirm against the original.
+static bool ConstrainLoopRange(const InductiveRangeCheck::Range &Range,
+                               Loop *L, ScalarEvolution &SE, LPPassManager &LPM,
+                               LoopInfo *LI, AssumptionTracker *AT,
+                               Pass *P) {
+
+  using namespace llvm::PatternMatch;
+
+  // ---- Part 1: bail-out checks; nothing is modified until they all pass ----
+
+  if (L->getBlocks().size() >= LoopSizeCutoff) {
+    DEBUG(dbgs() << "irce: giving up constraining loop, too large\n";);
+    return false;
+  }
+
+  assert(L->isLoopSimplifyForm() && "should follow from addRequired<>");
+
+  // NOTE(review): Latch is used here and for CIVNext before the
+  // assert(Latch && Preheader ...) further down.
+  BasicBlock *Latch = L->getLoopLatch();
+  if (!L->isLoopExiting(Latch)) {
+    DEBUG(dbgs() << "irce: giving up constraining loop, no loop latch\n";);
+    return false;
+  }
+
+  PHINode *CIV = L->getCanonicalInductionVariable();
+  if (!CIV) {
+    DEBUG(dbgs() << "irce: giving up constraining loop, no CIV\n";);
+    return false;
+  }
+
+  Value *CIVNext = CIV->getIncomingValueForBlock(Latch);
+  BasicBlock *Header = L->getHeader();
+  BasicBlock *Preheader = L->getLoopPreheader();
+
+  assert(Latch && Preheader && "supposed to be in LoopSimplify form");
+
+  BranchInst *LatchBr = dyn_cast(Latch->getTerminator());
+  const SCEV *LatchCountSCEV = SE.getExitCount(L, Latch);
+  if (isa(LatchCountSCEV)) {
+    DEBUG(dbgs() << "irce: giving up constraining loop, could not compute latch count\n";);
+    return false;
+  }
+
+  assert(SE.getLoopDisposition(LatchCountSCEV, L) == ScalarEvolution::LoopInvariant &&
+         "loop variant exit count doesn't make sense!");
+
+  // While SCEV does most of the analysis for us, we still have to
+  // modify the latch; and currently we can only deal with certain
+  // kinds of latches.  This can be made more sophisticated as needed.
+
+  if (!LatchBr || LatchBr->isUnconditional()) {
+    DEBUG(dbgs() << "irce: giving up constraining loop, latch terminator not conditional branch\n";);
+    return false;
+  }
+
+  // The LatchBrExitIdx'th successor of the latch block exits the loop
+  // by branching into LatchExit which isn't in the loop.
+
+  // -1 is only a placeholder (it wraps to the largest unsigned value); the
+  // index is set to 1 once the latch has been recognized.
+  unsigned LatchBrExitIdx = -1;
+  BasicBlock *LatchExit = nullptr;
+  {
+    // Currently we only support a latch condition of the form:
+    //
+    //  %condition = icmp slt %civNext, %limit
+    //  br i1 %condition, label %header, label %exit
+
+    if (LatchBr->getSuccessor(0) != Header) return false;
+
+    Value *CIVComparedTo = nullptr;
+    if (!match(LatchBr->getCondition(), m_ICmpWithPred(
+            m_Specific(CIVNext), m_Value(CIVComparedTo)))) {
+      DEBUG(dbgs() << "irce: giving up constraining loop, unknown latch form\n";);
+      return false;
+    }
+
+    const SCEV *CIVComparedToSCEV = SE.getSCEV(CIVComparedTo);
+    if (isa(CIVComparedToSCEV)) {
+      DEBUG(dbgs() << "irce: giving up constraining loop, could not relate CIV to latch expression\n";);
+      return false;
+    }
+
+    // The limit must be exactly (latch exit count + 1).
+    const SCEV *ShouldBeOne = SE.getMinusSCEV(CIVComparedToSCEV, LatchCountSCEV);
+    const SCEVConstant *SCEVOne = dyn_cast(ShouldBeOne);
+    if (!SCEVOne || SCEVOne->getValue()->getValue() != 1) {
+      DEBUG(dbgs() << "irce: giving up constraining loop, unexpected header count in latch\n";);
+      return false;
+    }
+
+    LatchBrExitIdx = 1;
+    LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);
+  }
+
+  assert(!L->contains(LatchExit) && "expected an exit block!");
+
+  // ---- Part 2: from here on the IR is modified ----
+
+  Function *Fn = Latch->getParent();
+  LLVMContext &Ctx = Fn->getContext();
+
+  // Once the induction variable reaches this value, the control
+  // leaves the pre loop.
+  Value *ExitPreLoopAt = nullptr;
+
+  // Once the induction variable reaches this value, the control
+  // leaves the main loop.
+  Value *ExitMainLoopAt = nullptr;
+
+  // LatchCount + 1.  The latch branches back to the header if
+  // "next(IV) slt HeaderCount".  We've verified this earlier.
+  Value *HeaderCount = nullptr;
+
+  {
+    IntegerType *Ty = cast(LatchCountSCEV->getType());
+
+    SCEVExpander Expander(SE, "irce");
+    Instruction *InsertPt = Preheader->getTerminator();
+
+    ConstantInt *One = ConstantInt::get(Ty, 1);
+    Value *LatchCount = Expander.expandCodeFor(LatchCountSCEV, Ty, InsertPt);
+
+    // I think we can be more aggressive here and make this nuw / nsw
+    // if the addition that feeds into the icmp for the latch's
+    // terminating branch is nuw / nsw.  In any case, a wrapping 2's
+    // complement addition is safe.
+    HeaderCount = BinaryOperator::Create(
+        BinaryOperator::Add, LatchCount, One, "hdr.count", InsertPt);
+
+    // Both loop limits are clamped to the original trip count.
+    // NOTE(review): assumes Range.first / Range.second have the same integer
+    // type as HeaderCount -- TODO confirm.
+    ExitPreLoopAt = ConstructSMinOf(HeaderCount, Range.first, InsertPt);
+    ExitMainLoopAt = ConstructSMinOf(HeaderCount, Range.second, InsertPt);
+  }
+
+  SmallVector PreLoopBlocks, PostLoopBlocks;
+  ValueToValueMapTy PreLoopMap, PostLoopMap;
+
+  ClonePreAndPostLoopBodies(Fn, L, PreLoopMap, PreLoopBlocks, PostLoopMap,
+                            PostLoopBlocks);
+  UpdateExitBlockPHIs(L, PreLoopMap, PreLoopBlocks, PostLoopMap,
+                      PostLoopBlocks);
+
+  IRBuilder<> B(Ctx);
+
+  BasicBlock *PreLoopLatch = cast(PreLoopMap[Latch]);
+  BasicBlock *PreLoopHeader = cast(PreLoopMap[Header]);
+
+// The pre loop and the main loop have a "selector block" that
+// decides whether the next iteration (and maybe the ones after
+// that) should continue in that loop, or if it should "skip" on to
+// the next loop.  The selector block is entered only if there is at
+// least one more iteration to be executed.  The post loop does not
+// have a selector block since it unconditionally executes the
+// remaining iterations.
+
+  BasicBlock *MainLoopSelector =
+      BasicBlock::Create(Ctx, "mainloop.selector", Fn, PreLoopHeader);
+  // The unreachable will be replaced later.
+  new UnreachableInst(Ctx, MainLoopSelector);
+
+  {
+    // Once we're in the preheader of the original loop, we know it is
+    // safe to execute at least one iteration in one of the three
+    // loops we have.  We change the preheader to become the pre
+    // loop's selector, which determines if we need to run in the pre
+    // loop at all.  Note that it means the preheader of the original
+    // loop ceases to remain a preheader.  PHI nodes remain unchanged.
+
+    assert(cast(CIV->getIncomingValueForBlock(Preheader))->getValue() == 0 &&
+           "should be true for canonical induction variables!");
+
+    ConstantInt *Zero = ConstantInt::get(cast(CIV->getType()), 0);
+    TerminatorInst *OldPreheaderBr = Preheader->getTerminator();
+    OldPreheaderBr->setSuccessor(0, PreLoopHeader);
+
+    B.SetInsertPoint(OldPreheaderBr);
+
+    // Do we have to run the iteration where CIV is Zero in the
+    // pre loop?
+    Value *ZeroIsInRange = B.CreateICmpSLT(Zero, Range.first);
+    B.CreateCondBr(ZeroIsInRange, PreLoopHeader, MainLoopSelector);
+    OldPreheaderBr->eraseFromParent();
+  }
+
+  BasicBlock *PreLoopExit = BasicBlock::Create(Ctx, "preloop.exit", Fn, Header);
+  Value *PreLoopCIVNext = PreLoopMap[CIVNext];
+
+  {
+    // The pre loop latch needs to be changed from running the full
+    // loop to jumping out as soon as it is okay to (possibly) start
+    // running the main loop.
+
+    TerminatorInst *PreLoopLatchBr = PreLoopLatch->getTerminator();
+
+    B.SetInsertPoint(PreLoopLatchBr);
+    B.CreateCondBr(B.CreateICmpSLT(PreLoopCIVNext, ExitPreLoopAt),
+                   PreLoopHeader, PreLoopExit);
+    PreLoopLatchBr->eraseFromParent();
+  }
+
+  {
+    B.SetInsertPoint(PreLoopExit);
+
+    // Did we run out of iterations?  If so, jump to the "original"
+    // loop exit.  Otherwise continue on to the main loop selector.
+
+    B.CreateCondBr(B.CreateICmpSLT(PreLoopCIVNext, HeaderCount),
+                   MainLoopSelector, LatchExit);
+  }
+
+  // The dynamic value of the induction variable in the main loop
+  // selector
+  PHINode *IVarInMainLoopSelector = nullptr;
+  auto RememberCIVCopy = [&IVarInMainLoopSelector, CIV](
+      PHINode *PN, PHINode *SelectorPHI) {
+    if (PN == CIV) IVarInMainLoopSelector = SelectorPHI;
+  };
+
+  UpdatePHINodesInHeader(Header, MainLoopSelector, Preheader, Latch,
+                         PreLoopExit, PreLoopMap, RememberCIVCopy);
+
+  BasicBlock *PostLoopHeader = cast(PostLoopMap[Header]);
+
+  {
+    // We know that we have to run at least one more iteration given the
+    // original limits of the loop.  Here we decide if it is okay to run some of
+    // those iterations in the main loop.
+
+    UnreachableInst *UI = cast(MainLoopSelector->getTerminator());
+    B.SetInsertPoint(UI);
+    B.CreateCondBr(B.CreateAnd(B.CreateICmpSLE(Range.first, IVarInMainLoopSelector),
+                               B.CreateICmpSLT(IVarInMainLoopSelector, Range.second)),
+                   Header, PostLoopHeader);
+    UI->eraseFromParent();
+  }
+
+  BasicBlock *MainLoopExit = BasicBlock::Create(Ctx, "mainloop.exit", Fn, PostLoopHeader);
+
+  {
+    // The main loop should only execute iterations where the canonical IV is in
+    // Range.
+
+    B.SetInsertPoint(LatchBr);
+
+    B.CreateCondBr(B.CreateICmpSLT(CIVNext, ExitMainLoopAt), Header,
+                   MainLoopExit);
+    LatchBr->eraseFromParent();
+    LatchBr = nullptr;
+  }
+
+  {
+    B.SetInsertPoint(MainLoopExit);
+
+    // Did we run out of iterations?  If so, jump to the "original"
+    // loop exit.  Otherwise continue on to the post loop header
+    B.CreateCondBr(B.CreateICmpSLT(CIVNext, HeaderCount), PostLoopHeader, LatchExit);
+  }
+
+  UpdatePHINodesInPostLoop(Header, Preheader, Latch, MainLoopExit,
+                           MainLoopSelector, PostLoopHeader, PostLoopMap);
+
+  // Update PHI nodes in LatchExit to reflect new control flow.  We already did
+  // part of the work when we fixed up exit blocks.  The pre and main loop no
+  // longer directly branch here from their latch.
+
+  for (Instruction &LEI : *LatchExit) {
+    PHINode *PN = dyn_cast(&LEI);
+    if (!PN) break;
+
+    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+      BasicBlock *BB = PN->getIncomingBlock(i);
+      if (BB == PreLoopLatch) {
+        PN->setIncomingBlock(i, PreLoopExit);
+      } else if (BB == Latch) {
+        PN->setIncomingBlock(i, MainLoopExit);
+      }
+    }
+  }
+
+  Loop *PreLoop = nullptr;
+  Loop *PostLoop = nullptr;
+
+  {
+    // Preserve the analysis passes that we said we'd preserve
+
+    PreLoop = LPM.cloneLoop(L, L->getParentLoop(), PreLoopMap, LI);
+    PostLoop = LPM.cloneLoop(L, L->getParentLoop(), PostLoopMap, LI);
+
+    // The three new glue blocks are added to the parent loop, if there is one.
+    if (Loop *PL = L->getParentLoop()) {
+      auto &LoopInfoBase = LI->getBase();
+      PL->addBasicBlockToLoop(MainLoopExit, LoopInfoBase);
+      PL->addBasicBlockToLoop(PreLoopExit, LoopInfoBase);
+      PL->addBasicBlockToLoop(MainLoopSelector, LoopInfoBase);
+    }
+
+    // Backedge counts have changed!
+    for (Loop *LoopIt = L; LoopIt; LoopIt = LoopIt->getParentLoop())
+      SE.forgetLoop(LoopIt);
+  }
+
+  // The casts to void keep PreLoop / PostLoop "used" when the verification
+  // below is compiled out.
+  (void) PreLoop;
+  (void) PostLoop;
+#ifndef NDEBUG
+  PreLoop->verifyLoop();
+  PostLoop->verifyLoop();
+  L->verifyLoop();
+  if (Loop *PL = L->getParentLoop())
+    PL->verifyLoop();
+#endif
+
+  assert(!verifyFunction(*Fn, &dbgs()));
+
+  return true;
+}
+
+/// Computes and returns a range of values for the induction variable in which
+/// the range check can be safely elided.  If it cannot compute such a range,
+/// returns None.  On success the instructions computing the two bounds are
+/// inserted before Loc.
+Optional
+InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE,
+                                               Instruction *Loc) const {
+
+  // Currently we support inequalities of the form:
+  //
+  //   0 <= Offset + 1 * CIV < L given L >= 0
+  //
+  // The inequality is satisfied by -Offset <= CIV < (L - Offset) [^1].  All
+  // additions and subtractions are twos-complement wrapping and comparisons are
+  // signed.
+  //
+  // Proof:
+  //
+  //   If there exists CIV such that -Offset <= CIV < (L - Offset) then it
+  //   follows that -Offset <= (-Offset + L) [== Eq. 1].  Since L >= 0, if
+  //   (-Offset + L) sign-overflows then (-Offset + L) < (-Offset).  Hence by
+  //   [Eq. 1], (-Offset + L) could not have overflowed.
+  //
+  //   This means CIV = t + (-Offset) for t in [0, L).  Hence (CIV + Offset) =
+  //   t.  Hence 0 <= (CIV + Offset) < L
+
+  // [^1]: Note that the solution does _not_ apply if L < 0; consider values
+  // Offset = 127, CIV = 126 and L = -2 in an i8 world.
+
+  // Only a constant scale of exactly 1 is handled.
+  const SCEVConstant *ScaleC = dyn_cast(getScale());
+  if (!(ScaleC && ScaleC->getValue()->getValue() == 1)) {
+    DEBUG(dbgs() << "irce: could not compute safe iteration space for:\n"; dump());
+    return None;
+  }
+
+  Value *OffsetV = SCEVExpander(SE, "safe.itr.space").expandCodeFor(
+      getOffset(), getOffset()->getType(), Loc);
+
+  // Begin = -Offset, End = Length - Offset.
+  Value *Begin = BinaryOperator::CreateNeg(OffsetV, "", Loc);
+  Value *End = BinaryOperator::CreateSub(getLength(), OffsetV, "", Loc);
+
+  return std::make_pair(Begin, End);
+}
+
+
+// Returns R2 when R1 holds no value; otherwise emits, before InsertPoint, the
+// signed maximum of the two lower bounds and the signed minimum of the two
+// upper bounds and returns that pair.
+static InductiveRangeCheck::Range
+IntersectRange(const Optional &R1,
+               const InductiveRangeCheck::Range &R2,
+               Instruction *InsertPoint) {
+  if (!R1.hasValue()) return R2;
+  auto &R1Value = R1.getValue();
+
+  Value *NewMin = ConstructSMaxOf(R1Value.first, R2.first, InsertPoint);
+  Value *NewMax = ConstructSMinOf(R1Value.second, R2.second, InsertPoint);
+  return std::make_pair(NewMin, NewMax);
+}
+
+
+// Collects the inductive range checks among the block terminators of L,
+// intersects their safe iteration spaces, asks ConstrainLoopRange to split the
+// loop accordingly and finally runs simplifyUsersOfIV on the canonical
+// induction variable.  Returns the result of ConstrainLoopRange.
+bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) {
+  InductiveRangeCheck::AllocatorTy IRCAlloc;
+  SmallVector RangeChecks;
+  ScalarEvolution &SE = getAnalysis();
+
+  for (auto BBI : L->getBlocks()) {
+    if (BranchInst *TBI = dyn_cast(BBI->getTerminator())) {
+      if (InductiveRangeCheck *IRC =
+          InductiveRangeCheck::create(IRCAlloc, TBI, L, SE))
+        RangeChecks.push_back(IRC);
+    }
+  }
+
+  if (RangeChecks.empty()) return false;
+
+  DEBUG(
+    dbgs() << "irce: looking at loop "; L->dump();
+    dbgs() << "irce: loop has " << RangeChecks.size() << " inductive range checks: \n";
+    for (InductiveRangeCheck *IRC : RangeChecks) {
+      IRC->dump();
+    });
+
+  BasicBlock *Preheader = L->getLoopPreheader();
+  if (!Preheader) {
+    DEBUG(dbgs() << "irce: loop has no preheader, leaving\n");
+    return false;
+  }
+
+  Optional SafeIterRange;
+  Instruction *ExprInsertPt = Preheader->getTerminator();
+
+#ifndef NDEBUG
+  SmallVector RangeChecksToEliminate;
+#endif
+
+  // NOTE(review): computeSafeIterationSpace and IntersectRange insert
+  // instructions before ExprInsertPt.  If ConstrainLoopRange fails afterwards,
+  // those instructions stay in the preheader while this function still
+  // returns false -- confirm that this is acceptable.
+  for (InductiveRangeCheck *IRC : RangeChecks) {
+    auto Result = IRC->computeSafeIterationSpace(SE, ExprInsertPt);
+    if (Result.hasValue()) {
+      SafeIterRange =
+        IntersectRange(SafeIterRange, Result.getValue(), ExprInsertPt);
+#ifndef NDEBUG
+      RangeChecksToEliminate.push_back(IRC);
+#endif
+    }
+  }
+
+  if (!SafeIterRange.hasValue()) return false;
+
+  AssertingVH CIV = L->getCanonicalInductionVariable();
+
+  bool Changed = ConstrainLoopRange(SafeIterRange.getValue(), L, SE, LPM, &
+                                    getAnalysis(),
+                                    &getAnalysis(),
+                                    this);
+  if (Changed) {
+    // NOTE(review): the message says "in function" but the value printed is
+    // the module identifier -- confirm intent.
+    DEBUG(dbgs() << "irce: in function " << L->getHeader()->getParent()->getParent()->getModuleIdentifier() << ": ";);
+    DEBUG(dbgs() << "constrained loop "; L->dump(););
+
+    // Optimize away the now-redundant range checks.
+    SmallVector Dead;
+    simplifyUsersOfIV(CIV, &SE, &LPM, Dead);
+
+#ifndef NDEBUG
+    // In the future we'd like to assert this in debug mode and directly do the
+    // constant fold in produce mode.  Currently there are some cases that
+    // simplifyUsersOfIV doesn't get.
+    for (InductiveRangeCheck *IRC : RangeChecksToEliminate) {
+      if (!isa(IRC->getBranch()->getCondition())) {
+        DEBUG(dbgs() << "irce: range check not eliminated: "; IRC->dump(););
+      }
+    }
+#endif
+  }
+
+  return Changed;
+}
+
+
+// Factory declared in llvm/Transforms/Scalar.h.
+Pass* llvm::createInductiveRangeCheckEliminationPass() {
+  return new InductiveRangeCheckElimination;
+}
Index: lib/Transforms/Scalar/Scalar.cpp
===================================================================
--- lib/Transforms/Scalar/Scalar.cpp
+++ lib/Transforms/Scalar/Scalar.cpp
@@ -40,6 +40,7 @@
   initializeGVNPass(Registry);
   initializeEarlyCSEPass(Registry);
   initializeFlattenCFGPassPass(Registry);
+  initializeInductiveRangeCheckEliminationPass(Registry);
   initializeIndVarSimplifyPass(Registry);
   initializeJumpThreadingPass(Registry);
   initializeLICMPass(Registry);
Index: test/Transforms/InductiveRangeCheckElimination/single.ll
===================================================================
--- /dev/null
+++ test/Transforms/InductiveRangeCheckElimination/single.ll
@@ -0,0 +1,138 @@
+; RUN: opt -irce -S < %s | FileCheck %s
+
+define void @rce.0(i32 *%arr, i32 *%a_len_ptr, i32 %n) {
+; CHECK-LABEL: rce.0
+ entry:
+  %len = load i32* %a_len_ptr, !range !0
+  %first.itr.check = icmp sgt i32 %n, 0
+  br i1 %first.itr.check, label %loop, label %exit
+
+ loop:
+; CHECK-LABEL: loop
+  %idx = phi i32 [ 0, %entry ] , [ %idx.next, %in.bounds ]
+  %idx.next = add i32 %idx, 1
+  %abc = icmp slt i32 %idx, %len
+  br i1 %abc, label %in.bounds, label %out.of.bounds
+; CHECK: br i1 true, label %in.bounds, label %out.of.bounds
+
+ in.bounds:
+  %addr = getelementptr i32* %arr, i32 %idx
+  store i32 0, i32* %addr
+  %next = icmp slt i32 %idx.next, %n
+  br i1 %next, label %loop, label %exit
+
+ out.of.bounds:
+  ret void
+
+ exit:
+  ret void
+}
+
+
+define i32 @rce.1(i8* %arr, i32 *%a_len) {
+; CHECK-LABEL: rce.1
+ entry:
+  %length.i = load i32* %a_len, !range !0
+  %entry.cond = icmp ne i32 %length.i, 0
+  br i1 %entry.cond, label %loop.preheader, label %exit.1
+
+ loop.preheader:
+  br label %loop
+
+; CHECK-LABEL: loop.preheader:
+; CHECK: [[RANGE_BEGIN:[^ ]+]] = sub i32 0, -1
+; CHECK: [[RANGE_END:[^ ]+]] = sub i32 %length.i, -1
+; CHECK: [[LATCH_COUNT:[^ ]+]] = add i32 %length.i, -1
+; CHECK: [[HEADER_COUNT:[^ ]+]] = add i32 [[LATCH_COUNT]], 1
+; CHECK: [[BEGIN_CMP:[^ ]+]] = icmp slt i32 [[HEADER_COUNT]], [[RANGE_BEGIN]]
+; CHECK: [[BEGIN_MAIN_LOOP:[^ ]+]] = select i1 [[BEGIN_CMP]], i32 [[HEADER_COUNT]], i32 [[RANGE_BEGIN]]
+; CHECK: [[END_CMP:[^ ]+]] = icmp slt i32 [[HEADER_COUNT]], [[RANGE_END]]
+; CHECK: [[END_MAIN_LOOP:[^ ]+]] = select i1 [[END_CMP]], i32 [[HEADER_COUNT]], i32 [[RANGE_END]]
+
+; CHECK-LABEL: preloop.exit:
+; CHECK: [[PRELOOP_ENTER_COND:[^ ]+]] = icmp slt i32 %i.next.preloop, [[HEADER_COUNT]]
+; CHECK: br i1 [[PRELOOP_ENTER_COND]], label %mainloop.selector, label %exit.0
+
+; CHECK-LABEL: loop:
+; CHECK: %i = phi i32 [ %i.next, %latch ], [ %selector.phi, %mainloop.selector ]
+
+; CHECK-LABEL: latch:
+; CHECK: %i.next = add i32 %i, 1
+; CHECK: [[BEGIN_SLT_CIV:[^ ]+]] = icmp slt i32 %i.next, [[END_MAIN_LOOP]]
+; CHECK: br i1 [[BEGIN_SLT_CIV]], label %loop, label %mainloop.exit
+
+
+; CHECK-LABEL: mainloop.selector:
+; CHECK: %selector.phi = phi i32 [ 0, %loop.preheader ], [ %i.next.preloop, %preloop.exit ]
+; CHECK: [[SEL_PHI_LOWER_BOUND:[^ ]+]] = icmp sle i32 [[RANGE_BEGIN]], %selector.phi
+; CHECK: [[SEL_PHI_UPPER_BOUND:[^ ]+]] = icmp slt i32 %selector.phi, [[RANGE_END]]
+; CHECK: [[SEL_PHI_BOUND:[^ ]+]] = and i1 [[SEL_PHI_LOWER_BOUND]], [[SEL_PHI_UPPER_BOUND]]
+; CHECK: br i1 [[SEL_PHI_BOUND:[^ ,]+]], label %loop, label %loop.postloop
+
+; CHECK-LABEL: loop.preloop:
+; CHECK: %i.preloop = phi i32 [ %i.next.preloop, %latch.preloop ], [ 0, %loop.preheader ]
+
+; CHECK-LABEL: mainloop.exit:
+; CHECK: [[CIV_SLT_HEADER_COUNT:[^ ]+]] = icmp slt i32 %i.next, [[HEADER_COUNT]]
+; CHECK: br i1 [[CIV_SLT_HEADER_COUNT:[^ ,]+]], label %loop.postloop, label %exit.0
+
+; CHECK-LABEL: loop.postloop:
+; CHECK: %i.postloop = phi i32 [ %i.next.postloop, %latch.postloop ], [ %i.next, %mainloop.exit ], [ %selector.phi, %mainloop.selector ]
+
+; CHECK-LABEL: latch.preloop:
+; CHECK: %i.next.preloop = add i32 %i.preloop, 1
+; CHECK: %latch.cond.preloop = icmp slt i32 %i.next.preloop, %length.i
+; CHECK: [[CIV_SLT_BEGIN:[^ ]+]] = icmp slt i32 %i.next.preloop, [[BEGIN_MAIN_LOOP]]
+; CHECK: br i1 [[CIV_SLT_BEGIN:[^ ,]+]], label %loop.preloop, label %preloop.exit
+
+; CHECK-LABEL: latch.postloop:
+; CHECK: %i.next.postloop = add i32 %i.postloop, 1
+; CHECK: %latch.cond.postloop = icmp slt i32 %i.next.postloop, %length.i
+; CHECK: br i1 %latch.cond.postloop, label %loop.postloop, label %exit.0
+
+ loop:
+  %i = phi i32 [ %i.next, %latch ], [ 0, %loop.preheader ]
+  %i.i64 = sext i32 %i to i64
+  %idx = add nsw i64 %i.i64, 12
+  %arr.addr = getelementptr inbounds i8* %arr, i64 %idx
+  %arr.elt = load i8* %arr.addr
+  %inner.cond = icmp eq i8 %arr.elt, -128
+  br i1 %inner.cond, label %inner, label %latch
+
+ inner:
+  %i.is.zero = icmp eq i32 %i, 0
+  br i1 %i.is.zero, label %exit.2, label %not_zero
+
+ not_zero:
+  %i.dec = add i32 %i, -1
+  %i.dec.in.range = icmp ult i32 %i.dec, %length.i
+  br i1 %i.dec.in.range, label %in_bounds, label %out.of.bounds
+
+ in_bounds:
+  %i.dec.i64 = sext i32 %i.dec to i64
+  %idx.1 = add nsw i64 %i.dec.i64, 12
+  %arr.addr.1 = getelementptr inbounds i8* %arr, i64 %idx.1
+  %arr.elt.1 = load i8* %arr.addr.1
+  %inner.cond.1 = icmp slt i8 %arr.elt.1, 0
+  br i1 %inner.cond.1, label %latch, label %exit.2
+
+ latch:
+  %i.next = add i32 %i, 1
+  %latch.cond = icmp slt i32 %i.next, %length.i
+  br i1 %latch.cond, label %loop, label %exit.0
+
+ exit.0:
+  ret i32 %i.next
+
+ exit.1:
+  ret i32 43
+
+ exit.2:
+  ret i32 44
+
+ out.of.bounds:
+  ret i32 45
+
+}
+
+!0 = !{i32 0, i32 2147483647}