Index: llvm/lib/Target/ARM/MVETailPredication.cpp
===================================================================
--- llvm/lib/Target/ARM/MVETailPredication.cpp
+++ llvm/lib/Target/ARM/MVETailPredication.cpp
@@ -35,12 +35,15 @@
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicsARM.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
 
 using namespace llvm;
 
@@ -56,8 +59,12 @@
 class MVETailPredication : public LoopPass {
   SmallVector<IntrinsicInst*, 4> MaskedInsts;
   Loop *L = nullptr;
+  LoopInfo *LI = nullptr;
+  const DataLayout *DL = nullptr;
+  DominatorTree *DT = nullptr;
   ScalarEvolution *SE = nullptr;
   TargetTransformInfo *TTI = nullptr;
+  TargetLibraryInfo *TLI = nullptr;
 
 public:
   static char ID;
@@ -69,7 +76,11 @@
     AU.addRequired<LoopInfoWrapperPass>();
     AU.addRequired<TargetPassConfig>();
     AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
     AU.addPreserved<LoopInfoWrapperPass>();
+    AU.setPreservesCFG();
+    getLoopAnalysisUsage(AU);
   }
 
   bool runOnLoop(Loop *L, LPPassManager&) override;
@@ -128,8 +139,13 @@
   auto &TPC = getAnalysis<TargetPassConfig>();
   auto &TM = TPC.getTM<TargetMachine>();
   auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
+  DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
   SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+  auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
+  TLI = TLIP ? &TLIP->getTLI(F) : nullptr;
+  DL = &L->getHeader()->getModule()->getDataLayout();
   this->L = L;
 
   // The MVE and LOB extensions are combined to enable tail-predication, but
@@ -185,7 +201,27 @@
 
   LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n"
                     << *Decrement << "\n");
-  return TryConvert(Setup->getArgOperand(0));
+
+  if (TryConvert(Setup->getArgOperand(0))) {
+    // Replace the loop's live-out values with equivalents expanded from
+    // their SCEV expressions in the exit blocks.
+    SmallVector<WeakTrackingVH, 16> DeadInsts;
+    SCEVExpander Rewriter(*SE, *DL, "mvetp");
+    ReplaceExitVal ReplaceExitValue = AlwaysRepl;
+
+    rewriteLoopExitValues(L, LI, TLI, SE, Rewriter, DT, ReplaceExitValue,
+                          DeadInsts);
+
+    // rewriteLoopExitValues only collects instructions that may have become
+    // dead; erase them here so no dead code is left behind.
+    while (!DeadInsts.empty())
+      if (auto *Inst =
+              dyn_cast_or_null<Instruction>(DeadInsts.pop_back_val()))
+        RecursivelyDeleteTriviallyDeadInstructions(Inst, TLI);
+    return true;
+  }
+
+  return false;
 }
 
 bool MVETailPredication::isTailPredicate(Instruction *I, Value *NumElements) {
@@ -549,4 +585,5 @@
 char MVETailPredication::ID = 0;
 
 INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false)
+INITIALIZE_PASS_DEPENDENCY(LoopPass)
 INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false)