Index: lib/Transforms/Scalar/LoopIdiomRecognize.cpp
===================================================================
--- lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -55,10 +55,13 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/BuildLibCalls.h"
 #include "llvm/Transforms/Utils/Local.h"
 using namespace llvm;
+using namespace llvm::PatternMatch;
 
 #define DEBUG_TYPE "loop-idiom"
 
@@ -75,6 +78,7 @@
   ScalarEvolution *SE;
   TargetLibraryInfo *TLI;
   const TargetTransformInfo *TTI;
+  LPPassManager *LPM;
 
 public:
   static char ID;
@@ -134,6 +138,8 @@
   void transformLoopToPopcount(BasicBlock *PreCondBB, Instruction *CntInst,
                                PHINode *CntPhi, Value *Var);
 
+  bool recognizeStrlen();
+
   /// @}
 };
 
@@ -186,7 +192,7 @@
 
   // Disable loop idiom recognition if the function's name is a common idiom.
   StringRef Name = L->getHeader()->getParent()->getName();
-  if (Name == "memset" || Name == "memcpy")
+  if (Name == "memset" || Name == "memcpy" || Name == "strlen")
     return false;
 
   AA = &getAnalysis<AliasAnalysis>();
@@ -196,6 +202,7 @@
   TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
       *CurLoop->getHeader()->getParent());
+  this->LPM = &LPM;
 
   if (SE->hasLoopInvariantBackedgeTakenCount(L))
     return runOnCountableLoop();
@@ -654,6 +661,9 @@
   if (recognizePopcount())
     return true;
 
+  if (recognizeStrlen())
+    return true;
+
   return false;
 }
 
@@ -989,3 +999,190 @@
   // loop. The loop would otherwise not be deleted even if it becomes empty.
   SE->forgetLoop(CurLoop);
 }
+
+/// Recognize a single-block loop that advances an i8* induction pointer by one
+/// byte per iteration, loads through the advanced pointer, and exits as soon
+/// as that load yields zero; then replace the loop with one strlen() call in
+/// the preheader. Integer counters of the form IV += Step (Step loop-invariant)
+/// are rewritten in closed form. Returns true if the loop was removed.
+bool LoopIdiomRecognize::recognizeStrlen() {
+  // If we're not allowed to introduce a strlen, don't try.
+  if (!TLI->has(LibFunc::strlen))
+    return false;
+
+  // Give up if the loop has multiple blocks or multiple backedges.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
+    return false;
+
+  BasicBlock *LoopBody = *(CurLoop->block_begin());
+  if (LoopBody->size() >= 20) {
+    // The loop is too big, bail out.
+    return false;
+  }
+
+  // If successful, there won't be a loop to take the address of.
+  if (LoopBody->hasAddressTaken())
+    return false;
+
+  // It should have a preheader containing nothing but an unconditional branch.
+  BasicBlock *PH = CurLoop->getLoopPreheader();
+  if (!PH)
+    return false;
+  if (&PH->front() != PH->getTerminator())
+    return false;
+  auto *EntryBI = dyn_cast<BranchInst>(PH->getTerminator());
+  if (!EntryBI || EntryBI->isConditional())
+    return false;
+
+  // Check that this loop has a pointer that increments by one byte each
+  // iteration and exits the loop when a load of our pointer returns i8 0.
+  // The exit must be the "true" edge of the compare: with the edges swapped
+  // the loop keeps going while the byte *is* zero, which is not strlen.
+  auto *Backedge = dyn_cast<BranchInst>(LoopBody->getTerminator());
+  if (!Backedge || Backedge->isUnconditional() ||
+      Backedge->getSuccessor(0) == LoopBody ||
+      Backedge->getSuccessor(1) != LoopBody)
+    return false;
+  Value *LoopCond = Backedge->getCondition();
+  Value *LoadV;
+  CmpInst::Predicate Pred;
+  if (!match(LoopCond, m_ICmp(Pred, m_Value(LoadV), m_Zero())) ||
+      Pred != CmpInst::ICMP_EQ)
+    return false;
+  // The load is deleted below, so it must not be volatile or atomic.
+  auto *Load = dyn_cast<LoadInst>(LoadV);
+  if (!Load || !Load->isSimple() || Load->getParent() != LoopBody)
+    return false;
+  Value *Ptr = Load->getPointerOperand();
+  if (Ptr->getType() != Type::getInt8PtrTy(Ptr->getContext()))
+    return false;
+  // The loaded pointer must be an in-loop GEP that advances by one byte. It
+  // may just as well be an Argument or a global, hence dyn_cast, not cast.
+  auto *PtrGEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!PtrGEP || PtrGEP->getParent() != LoopBody)
+    return false;
+  auto &DL = CurLoop->getHeader()->getModule()->getDataLayout();
+  APInt Offset(DL.getPointerTypeSizeInBits(Ptr->getType()), 0);
+  if (!PtrGEP->accumulateConstantOffset(DL, Offset) || Offset != 1)
+    return false;
+  // The GEP's base must be a PHI of LoopBody itself: a PHI in any other block
+  // has no LoopBody incoming edge, and getIncomingValueForBlock would assert.
+  auto *PtrPHI = dyn_cast<PHINode>(PtrGEP->getPointerOperand());
+  if (!PtrPHI || PtrPHI->getParent() != LoopBody ||
+      PtrPHI->getIncomingValueForBlock(LoopBody) != PtrGEP)
+    return false;
+  // The pointer increment is deleted below, so nothing but the PHI and the
+  // load may use it; e.g. an LCSSA use of the final pointer would otherwise
+  // be left reading a bogus value.
+  for (User *U : PtrGEP->users())
+    if (U != PtrPHI && U != Load)
+      return false;
+
+  // Check that this loop does nothing else each iteration.
+  for (Instruction *Inst = &LoopBody->front(); !isa<TerminatorInst>(Inst);
+       Inst = Inst->getNextNode()) {
+    if (Inst == PtrPHI || Inst == Load || Inst == LoopCond) {
+      if (!Inst->hasOneUse())
+        return false;
+      continue;
+    }
+    if (!Inst->getType()->isIntOrIntVectorTy() &&
+        !Inst->getType()->isPointerTy())
+      return false;
+    if (Inst->mayHaveSideEffects())
+      return false;
+    if (PHINode *IndVar = dyn_cast<PHINode>(Inst)) {
+      Value *Start = IndVar->getIncomingValueForBlock(PH);
+      if (!CurLoop->isLoopInvariant(Start))
+        return false;
+      Value *Step;
+      if (!match(IndVar->getIncomingValueForBlock(LoopBody),
+                 m_Add(m_Specific(IndVar), m_Value(Step))))
+        return false;
+      if (!CurLoop->isLoopInvariant(Step))
+        return false;
+      continue;
+    }
+    switch (Inst->getOpcode()) {
+    case Instruction::Add:
+    case Instruction::GetElementPtr:
+    case Instruction::ZExt:
+    case Instruction::SExt:
+    case Instruction::Trunc:
+      for (unsigned i = 0, e = Inst->getNumOperands(); i != e; ++i) {
+        Instruction *Op = dyn_cast<Instruction>(Inst->getOperand(i));
+        if (!Op) continue;
+        if (Op->getParent() != LoopBody)
+          return false;
+      }
+      break;
+    default:
+      return false;
+    }
+  }
+
+  // Rewrite the integer counters in terms of strlen().
+  //
+  // Iteration k (k >= 1) loads StrBegin[k], so the loop runs
+  // N = 1 + strlen(StrBegin + 1) times and, on exit, each counter PHI holds
+  // Start + Step * (N - 1) = Start + Step * strlen(StrBegin + 1).
+  // Measuring from StrBegin + 1 (rather than strlen(StrBegin) - 1) reads only
+  // bytes the loop itself reads and stays correct when StrBegin[0] == 0,
+  // which nothing above rules out.
+  IRBuilder<> Builder(PH->getTerminator());
+  Value *StrBegin = PtrPHI->getIncomingValueForBlock(PH);
+  Value *StrLen = llvm::EmitStrLen(Builder.CreateConstGEP1_64(StrBegin, 1),
+                                   Builder, DL, TLI);
+  for (PHINode *IndVar = dyn_cast<PHINode>(&LoopBody->front());
+       IndVar;
+       IndVar = dyn_cast<PHINode>(IndVar->getNextNode())) {
+    if (IndVar == PtrPHI)
+      continue;
+
+    Value *Start = IndVar->getIncomingValueForBlock(PH);
+    Value *Step;
+    if (!match(IndVar->getIncomingValueForBlock(LoopBody),
+               m_Add(m_Specific(IndVar), m_Value(Step))))
+      llvm_unreachable("Failed to find loop step that we already found");
+
+    Value *StrLenV =
+        Builder.CreateZExtOrTrunc(StrLen, IndVar->getType()->getScalarType());
+    if (VectorType *VTy = dyn_cast<VectorType>(IndVar->getType()))
+      StrLenV = Builder.CreateVectorSplat(VTy->getNumElements(), StrLenV);
+    Value *Result = Builder.CreateAdd(Start, Builder.CreateMul(Step, StrLenV));
+
+    IndVar->replaceAllUsesWith(Result);
+  }
+
+  BasicBlock *ExitBlock = CurLoop->getExitBlock();
+  assert(ExitBlock && "Loop ends with branch but has multiple exit blocks?");
+
+  SE->forgetLoop(CurLoop);
+  LPM->deleteLoopFromQueue(CurLoop);
+
+  // The compare, the load and the pointer recurrence (PHI + increment) are
+  // now dead. RAUW first so no erased value is still referenced, then erase
+  // the PHI before the increment, since the PHI uses the increment.
+  LoopCond->replaceAllUsesWith(UndefValue::get(LoopCond->getType()));
+  Load->replaceAllUsesWith(UndefValue::get(Load->getType()));
+  PtrPHI->replaceAllUsesWith(UndefValue::get(PtrPHI->getType()));
+  cast<Instruction>(LoopCond)->eraseFromParent();
+  Load->eraseFromParent();
+  PtrPHI->eraseFromParent();
+  PtrGEP->eraseFromParent();
+
+  PH->getInstList().splice(Builder.GetInsertPoint(), LoopBody->getInstList(),
+                           LoopBody->getFirstNonPHI(),
+                           LoopBody->getTerminator());
+
+  // At this stage, LoopBody contains only use_empty PHI nodes and a terminator.
+  // NOTE(review): LoopBody is erased without a DominatorTree update. If this
+  // pass reports the DominatorTree as preserved, ExitBlock's idom has to move
+  // to PH and LoopBody's node has to be erased -- confirm.
+
+  LoopBody->replaceSuccessorsPhiUsesWith(PH);
+  LoopBody->replaceAllUsesWith(ExitBlock);
+  LoopBody->eraseFromParent();
+
+  return true;
+}
Index: test/Transforms/LoopIdiom/strlen.ll
===================================================================
--- test/Transforms/LoopIdiom/strlen.ll
+++ test/Transforms/LoopIdiom/strlen.ll
@@ -0,0 +1,82 @@
+; RUN: opt -loop-idiom -S < %s | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+define i32 @test1(i8* %p) {
+; CHECK-LABEL: @test1
+entry:
+  %0 = load i8, i8* %p, align 1
+  %tobool.5 = icmp eq i8 %0, 0
+  br i1 %tobool.5, label %for.cond.cleanup, label %for.inc.lr.ph
+
+for.inc.lr.ph:                                    ; preds = %entry
+; CHECK: for.inc.lr.ph:
+; CHECK-NEXT: getelementptr i8, i8* %p, i64 1
+; CHECK-NEXT: @strlen
+; CHECK-NOT: icmp
+  br label %for.inc
+
+for.cond.for.cond.cleanup_crit_edge:              ; preds = %for.inc
+  %inc.lcssa = phi i32 [ %inc, %for.inc ]
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.for.cond.cleanup_crit_edge, %entry
+  %len.0.lcssa = phi i32 [ %inc.lcssa, %for.cond.for.cond.cleanup_crit_edge ], [ 0, %entry ]
+  ret i32 %len.0.lcssa
+
+for.inc:                                          ; preds = %for.inc.lr.ph, %for.inc
+  %ptr.07 = phi i8* [ %p, %for.inc.lr.ph ], [ %incdec.ptr, %for.inc ]
+  %len.06 = phi i32 [ 0, %for.inc.lr.ph ], [ %inc, %for.inc ]
+  %incdec.ptr = getelementptr inbounds i8, i8* %ptr.07, i64 1
+  %inc = add nuw nsw i32 %len.06, 1
+  %1 = load i8, i8* %incdec.ptr, align 1
+  %tobool = icmp eq i8 %1, 0
+  br i1 %tobool, label %for.cond.for.cond.cleanup_crit_edge, label %for.inc
+}
+
+; The exit is on the "false" edge, so this loop runs while the byte is zero.
+; That is not a strlen idiom and must be left alone.
+define i32 @test2(i8* %p) {
+; CHECK-LABEL: @test2
+; CHECK-NOT: @strlen
+; CHECK: ret i32
+entry:
+  br label %loop
+
+loop:
+  %ptr = phi i8* [ %p, %entry ], [ %ptr.next, %loop ]
+  %len = phi i32 [ 0, %entry ], [ %inc, %loop ]
+  %ptr.next = getelementptr inbounds i8, i8* %ptr, i64 1
+  %inc = add nuw nsw i32 %len, 1
+  %c = load i8, i8* %ptr.next, align 1
+  %iszero = icmp eq i8 %c, 0
+  br i1 %iszero, label %loop, label %exit
+
+exit:
+  ret i32 %inc
+}
+
+; Nothing guards p[0], so the count has to come from strlen(p + 1); computing
+; strlen(p) - 1 instead would wrap to -1 when p[0] == 0.
+define i32 @test3(i8* %p) {
+; CHECK-LABEL: @test3
+; CHECK: entry:
+; CHECK-NEXT: [[GEP:%[^ ]+]] = getelementptr i8, i8* %p, i64 1
+; CHECK-NEXT: @strlen(i8* [[GEP]])
+; CHECK-NOT: icmp
+entry:
+  br label %loop
+
+loop:
+  %ptr = phi i8* [ %p, %entry ], [ %ptr.next, %loop ]
+  %len = phi i32 [ 0, %entry ], [ %inc, %loop ]
+  %ptr.next = getelementptr inbounds i8, i8* %ptr, i64 1
+  %inc = add nuw nsw i32 %len, 1
+  %c = load i8, i8* %ptr.next, align 1
+  %iszero = icmp eq i8 %c, 0
+  br i1 %iszero, label %exit, label %loop
+
+exit:
+  ret i32 %inc
+}