Index: llvm/include/llvm/CodeGen/ReachingDefAnalysis.h =================================================================== --- llvm/include/llvm/CodeGen/ReachingDefAnalysis.h +++ llvm/include/llvm/CodeGen/ReachingDefAnalysis.h @@ -109,6 +109,17 @@ /// reaching def instuction of PhysReg that reaches MI. int getClearance(MachineInstr *MI, MCPhysReg PhysReg); + /// Return whether A and B use the same def of PhysReg. + bool hasSameReachingDef(MachineInstr *A, MachineInstr *B, int PhysReg); + + /// Provides the uses, in the same block as MI, of register that MI defines. + void getReachingUses(MachineInstr *MI, int PhysReg, + SmallVectorImpl &Uses); + + /// Provide the number of uses, in the same block as MI, of the register that + /// MI defines. + unsigned getNumUses(MachineInstr *MI, int PhysReg); + private: /// Set up LiveRegs by merging predecessor live-out values. void enterBasicBlock(const LoopTraversal::TraversedMBBInfo &TraversedMBB); Index: llvm/lib/CodeGen/ReachingDefAnalysis.cpp =================================================================== --- llvm/lib/CodeGen/ReachingDefAnalysis.cpp +++ llvm/lib/CodeGen/ReachingDefAnalysis.cpp @@ -220,3 +220,39 @@ assert(InstIds.count(MI) && "Unexpected machine instuction."); return InstIds[MI] - getReachingDef(MI, PhysReg); } + +void +ReachingDefAnalysis::getReachingUses(MachineInstr *Def, int PhysReg, + SmallVectorImpl &Uses) { + MachineBasicBlock *MBB = Def->getParent(); + MachineBasicBlock::iterator MI = MachineBasicBlock::iterator(Def); + while (++MI != MBB->end()) { + for (auto &MO : MI->operands()) { + if (!MO.isReg() || !MO.isUse() || MO.getReg() != PhysReg) + continue; + + if (getReachingMIDef(&*MI, PhysReg) != Def) + return; + + Uses.push_back(&*MI); + if (MO.isKill()) + return; + } + } +} + +unsigned ReachingDefAnalysis::getNumUses(MachineInstr *Def, int PhysReg) { + SmallVector Uses; + getReachingUses(Def, PhysReg, Uses); + return Uses.size(); +} + +bool ReachingDefAnalysis::hasSameReachingDef(MachineInstr *A, MachineInstr *B, + int PhysReg) { + MachineBasicBlock *ParentA = A->getParent(); + MachineBasicBlock *ParentB = B->getParent(); + if (ParentA != ParentB) + return false; + + return getReachingDef(A, PhysReg) == getReachingDef(B, PhysReg); +} Index: llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp =================================================================== --- llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp +++ llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp @@ -25,6 +25,7 @@ #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/MachineLoopInfo.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/ReachingDefAnalysis.h" #include "llvm/MC/MCInstrDesc.h" using namespace llvm; @@ -37,6 +38,7 @@ struct LowOverheadLoop { MachineLoop *ML = nullptr; + ReachingDefAnalysis *RDA = nullptr; MachineFunction *MF = nullptr; MachineInstr *InsertPt = nullptr; MachineInstr *Start = nullptr; @@ -48,7 +50,8 @@ bool FoundOneVCTP = false; bool CannotTailPredicate = false; - LowOverheadLoop(MachineLoop *ML) : ML(ML) { + LowOverheadLoop(MachineLoop *ML, ReachingDefAnalysis *RDA) + : ML(ML), RDA(RDA) { MF = ML->getHeader()->getParent(); } @@ -113,6 +116,35 @@ return Start && Dec && End; } + // Return the loop iteration count, or the number of elements if we're tail + // predicating. + MachineOperand &getCount() { + return IsTailPredicationLegal() ? + VCTP->getOperand(1) : Start->getOperand(0); + } + + unsigned getStartOpcode() const { + bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart; + if (!IsTailPredicationLegal()) + return IsDo ? ARM::t2DLS : ARM::t2WLS; + else { + switch (VCTP->getOpcode()) { + default: + llvm_unreachable("unhandled vctp opcode"); + break; + case ARM::MVE_VCTP8: + return IsDo ? ARM::MVE_DLSTP_8 : ARM::MVE_WLSTP_8; + case ARM::MVE_VCTP16: + return IsDo ? ARM::MVE_DLSTP_16 : ARM::MVE_WLSTP_16; + case ARM::MVE_VCTP32: + return IsDo ? ARM::MVE_DLSTP_32 : ARM::MVE_WLSTP_32; + case ARM::MVE_VCTP64: + return IsDo ? ARM::MVE_DLSTP_64 : ARM::MVE_WLSTP_64; + } + } + return 0; + } + void dump() const { if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start; if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec; @@ -127,6 +159,7 @@ class ARMLowOverheadLoops : public MachineFunctionPass { MachineFunction *MF = nullptr; + ReachingDefAnalysis *RDA = nullptr; const ARMBaseInstrInfo *TII = nullptr; MachineRegisterInfo *MRI = nullptr; std::unique_ptr BBUtils = nullptr; @@ -139,6 +172,7 @@ void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); AU.addRequired(); + AU.addRequired(); MachineFunctionPass::getAnalysisUsage(AU); } @@ -146,7 +180,8 @@ MachineFunctionProperties getRequiredProperties() const override { return MachineFunctionProperties().set( - MachineFunctionProperties::Property::NoVRegs); + MachineFunctionProperties::Property::NoVRegs).set( + MachineFunctionProperties::Property::TracksLiveness); } StringRef getPassName() const override { @@ -320,13 +355,51 @@ return; } - InsertPt = Revert ? nullptr : IsSafeToDefineLR(); + InsertPt = IsSafeToDefineLR(); if (!InsertPt) { LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n"); Revert = true; + return; } else LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt); + // For tail predication, we need to provide the number of elements, instead + // of the iteration count, to the loop start instruction. The number of + // elements is provided to the vctp instruction, so we need to check that + // we can use this register at InsertPt. + if (IsTailPredicationLegal()) { + Register NumElements = VCTP->getOperand(0).getReg(); + + // If the register is defined within loop, then we can't perform TP. + if (RDA->getReachingDef(VCTP, NumElements) >= 0) + CannotTailPredicate = true; + else { + MachineBasicBlock *InsertBB = InsertPt->getParent(); + // We can't perform TP if the register does not hold the same value at + // InsertPt as the liveout value. + if (!RDA->hasSameReachingDef(InsertPt, &InsertBB->back(), + NumElements)) + CannotTailPredicate = true; + else { + // Especially in the case of while loops, InsertBB may not be the + // preheader, so we need to check that the register isn't redefined + // before entering the loop. + MachineBasicBlock *MBB = ML->getLoopPreheader(); + while (MBB != InsertBB) { + if (RDA->getReachingDef(&MBB->back(), NumElements) >= 0) { + CannotTailPredicate = true; + break; + } + if (MBB->pred_size() > 1) { + CannotTailPredicate = true; + break; + } + MBB = *MBB->pred_begin(); + } + } + } + } + LLVM_DEBUG(if (IsTailPredicationLegal()) { dbgs() << "ARM Loops: Will use tail predication to convert:\n"; for (auto *MI : VPTUsers) @@ -343,6 +416,7 @@ LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n"); auto &MLI = getAnalysis(); + RDA = &getAnalysis(); MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness); MRI = &MF->getRegInfo(); TII = static_cast(ST.getInstrInfo()); @@ -382,7 +456,7 @@ return nullptr; }; - LowOverheadLoop LoLoop(ML); + LowOverheadLoop LoLoop(ML, RDA); // Search the preheader for the start intrinsic, or look through the // predecessors of the header to find exactly one set.iterations intrinsic. // FIXME: I don't see why we shouldn't be supporting multiple predecessors @@ -558,35 +632,45 @@ MachineInstr *Start = LoLoop.Start; MachineBasicBlock *MBB = InsertPt->getParent(); bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart; - unsigned Opc = 0; - - if (!LoLoop.IsTailPredicationLegal()) - Opc = IsDo ? ARM::t2DLS : ARM::t2WLS; - else { - switch (LoLoop.VCTP->getOpcode()) { - case ARM::MVE_VCTP8: - Opc = IsDo ? ARM::MVE_DLSTP_8 : ARM::MVE_WLSTP_8; - break; - case ARM::MVE_VCTP16: - Opc = IsDo ? ARM::MVE_DLSTP_16 : ARM::MVE_WLSTP_16; - break; - case ARM::MVE_VCTP32: - Opc = IsDo ? ARM::MVE_DLSTP_32 : ARM::MVE_WLSTP_32; - break; - case ARM::MVE_VCTP64: - Opc = IsDo ? ARM::MVE_DLSTP_64 : ARM::MVE_WLSTP_64; - break; - } - } + unsigned Opc = LoLoop.getStartOpcode(); + MachineOperand &Count = LoLoop.getCount(); MachineInstrBuilder MIB = BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc)); MIB.addDef(ARM::LR); - MIB.add(Start->getOperand(0)); + MIB.add(Count); if (!IsDo) MIB.add(Start->getOperand(1)); + // When using tail-predication, try to delete the dead code that was used to + // calculate the number of loop iterations. + if (LoLoop.IsTailPredicationLegal()) { + SmallVector Killed; + SmallVector Dead; + if (auto *Def = RDA->getReachingMIDef(Start, + Start->getOperand(0).getReg())) { + Killed.push_back(Def); + + while (!Killed.empty()) { + MachineInstr *Def = Killed.back(); + Killed.pop_back(); + Dead.push_back(Def); + for (auto &MO : Def->operands()) { + if (!MO.isReg() || !MO.isKill()) + continue; + + MachineInstr *Kill = RDA->getReachingMIDef(Def, MO.getReg()); + if (Kill && RDA->getNumUses(Kill, MO.getReg()) == 1) + Killed.push_back(Kill); + } + } + for (auto *MI : Dead) + MI->eraseFromParent(); + } + } + + // If we're inserting at a mov lr, then remove it as it's redundant. if (InsertPt != Start) InsertPt->eraseFromParent(); Start->eraseFromParent(); Index: llvm/test/CodeGen/Thumb2/LowOverheadLoops/fast-fp-loops.ll =================================================================== --- llvm/test/CodeGen/Thumb2/LowOverheadLoops/fast-fp-loops.ll +++ llvm/test/CodeGen/Thumb2/LowOverheadLoops/fast-fp-loops.ll @@ -36,11 +36,7 @@ ; CHECK-NEXT: mov.w r12, #0 ; CHECK-NEXT: b .LBB0_8 ; CHECK-NEXT: .LBB0_4: @ %vector.ph -; CHECK-NEXT: adds r6, r3, #3 -; CHECK-NEXT: bic r6, r6, #3 -; CHECK-NEXT: subs r6, #4 -; CHECK-NEXT: add.w lr, r12, r6, lsr #2 -; CHECK-NEXT: dlstp.32 lr, lr +; CHECK-NEXT: dlstp.32 lr, r3 ; CHECK-NEXT: .LBB0_5: @ %vector.body ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 ; CHECK-NEXT: vldrw.u32 q0, [r1] Index: llvm/test/CodeGen/Thumb2/LowOverheadLoops/mve-tail-data-types.ll =================================================================== --- llvm/test/CodeGen/Thumb2/LowOverheadLoops/mve-tail-data-types.ll +++ llvm/test/CodeGen/Thumb2/LowOverheadLoops/mve-tail-data-types.ll @@ -349,13 +349,8 @@ ; CHECK-NEXT: moveq r0, #0 ; CHECK-NEXT: bxeq lr ; CHECK-NEXT: push {r7, lr} -; CHECK-NEXT: adds r3, r2, #3 ; CHECK-NEXT: vmov.i32 q0, #0x0 -; CHECK-NEXT: bic r3, r3, #3 -; CHECK-NEXT: sub.w r12, r3, #4 -; CHECK-NEXT: movs r3, #1 -; CHECK-NEXT: add.w lr, r3, r12, lsr #2 -; CHECK-NEXT: dlstp.32 lr, lr +; CHECK-NEXT: dlstp.32 lr, r2 ; CHECK-NEXT: .LBB4_1: @ %vector.body ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 ; CHECK-NEXT: vldrw.u32 q2, [r1] @@ -1126,11 +1121,7 @@ ; CHECK-NEXT: mov.w r12, #0 ; CHECK-NEXT: b .LBB9_8 ; CHECK-NEXT: .LBB9_4: @ %vector.ph -; CHECK-NEXT: add.w r4, r12, #3 -; CHECK-NEXT: bic r4, r4, #3 -; CHECK-NEXT: subs r4, #4 -; CHECK-NEXT: add.w lr, lr, r4, lsr #2 -; CHECK-NEXT: dlstp.32 lr, lr +; CHECK-NEXT: dlstp.32 lr, r12 ; CHECK-NEXT: .LBB9_5: @ %vector.body ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 ; CHECK-NEXT: vldrw.u32 q0, [r0] Index: llvm/test/CodeGen/Thumb2/LowOverheadLoops/vector-arith-codegen.ll =================================================================== --- llvm/test/CodeGen/Thumb2/LowOverheadLoops/vector-arith-codegen.ll +++ llvm/test/CodeGen/Thumb2/LowOverheadLoops/vector-arith-codegen.ll @@ -9,13 +9,8 @@ ; CHECK-NEXT: moveq r0, #0 ; CHECK-NEXT: bxeq lr ; CHECK-NEXT: push {r7, lr} -; CHECK-NEXT: adds r3, r2, #3 ; CHECK-NEXT: vmov.i32 q0, #0x0 -; CHECK-NEXT: bic r3, r3, #3 -; CHECK-NEXT: sub.w r12, r3, #4 -; CHECK-NEXT: movs r3, #1 -; CHECK-NEXT: add.w lr, r3, r12, lsr #2 -; CHECK-NEXT: dlstp.32 lr, lr +; CHECK-NEXT: dlstp.32 lr, r2 ; CHECK-NEXT: .LBB0_1: @ %vector.body ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 ; CHECK-NEXT: vmov q1, q0 @@ -82,13 +77,8 @@ ; CHECK-NEXT: moveq r0, #0 ; CHECK-NEXT: bxeq lr ; CHECK-NEXT: push {r7, lr} -; CHECK-NEXT: adds r1, r2, #3 -; CHECK-NEXT: movs r3, #1 -; CHECK-NEXT: bic r1, r1, #3 ; CHECK-NEXT: vmov.i32 q0, #0x0 -; CHECK-NEXT: subs r1, #4 -; CHECK-NEXT: add.w lr, r3, r1, lsr #2 -; CHECK-NEXT: dlstp.32 lr, lr +; CHECK-NEXT: dlstp.32 lr, r2 ; CHECK-NEXT: .LBB1_1: @ %vector.body ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 ; CHECK-NEXT: mov r1, r2 @@ -148,13 +138,8 @@ ; CHECK-NEXT: moveq r0, #0 ; CHECK-NEXT: bxeq lr ; CHECK-NEXT: push {r7, lr} -; CHECK-NEXT: adds r1, r2, #3 -; CHECK-NEXT: movs r3, #1 -; CHECK-NEXT: bic r1, r1, #3 ; CHECK-NEXT: vmov.i32 q0, #0x0 -; CHECK-NEXT: subs r1, #4 -; CHECK-NEXT: add.w lr, r3, r1, lsr #2 -; CHECK-NEXT: dlstp.32 lr, lr +; CHECK-NEXT: dlstp.32 lr, r2 ; CHECK-NEXT: .LBB2_1: @ %vector.body ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 ; CHECK-NEXT: mov r1, r2 @@ -213,12 +198,7 @@ ; CHECK-NEXT: cmp r3, #0 ; CHECK-NEXT: it eq ; CHECK-NEXT: popeq {r7, pc} -; CHECK-NEXT: add.w r12, r3, #3 -; CHECK-NEXT: mov.w lr, #1 -; CHECK-NEXT: bic r12, r12, #3 -; CHECK-NEXT: sub.w r12, r12, #4 -; CHECK-NEXT: add.w lr, lr, r12, lsr #2 -; CHECK-NEXT: dlstp.32 lr, lr +; CHECK-NEXT: dlstp.32 lr, r3 ; CHECK-NEXT: .LBB3_1: @ %vector.body ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 ; CHECK-NEXT: vldrw.u32 q0, [r1] @@ -272,12 +252,7 @@ ; CHECK-NEXT: cmp r3, #0 ; CHECK-NEXT: it eq ; CHECK-NEXT: popeq {r7, pc} -; CHECK-NEXT: add.w r12, r3, #3 -; CHECK-NEXT: mov.w lr, #1 -; CHECK-NEXT: bic r12, r12, #3 -; CHECK-NEXT: sub.w r12, r12, #4 -; CHECK-NEXT: add.w lr, lr, r12, lsr #2 -; CHECK-NEXT: dlstp.32 lr, lr +; CHECK-NEXT: dlstp.32 lr, r3 ; CHECK-NEXT: .LBB4_1: @ %vector.body ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 ; CHECK-NEXT: vldrw.u32 q0, [r1] @@ -331,13 +306,8 @@ ; CHECK-NEXT: cmp r3, #0 ; CHECK-NEXT: it eq ; CHECK-NEXT: popeq {r4, pc} -; CHECK-NEXT: add.w r12, r3, #15 -; CHECK-NEXT: mov.w lr, #1 -; CHECK-NEXT: bic r12, r12, #15 -; CHECK-NEXT: sub.w r12, r12, #16 -; CHECK-NEXT: add.w lr, lr, r12, lsr #4 ; CHECK-NEXT: mov.w r12, #0 -; CHECK-NEXT: dlstp.8 lr, lr +; CHECK-NEXT: dlstp.8 lr, r3 ; CHECK-NEXT: .LBB5_1: @ %vector.body ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 ; CHECK-NEXT: add.w r4, r1, r12 @@ -396,12 +366,7 @@ ; CHECK-NEXT: cmp r3, #0 ; CHECK-NEXT: it eq ; CHECK-NEXT: popeq {r7, pc} -; CHECK-NEXT: add.w r12, r3, #7 -; CHECK-NEXT: mov.w lr, #1 -; CHECK-NEXT: bic r12, r12, #7 -; CHECK-NEXT: sub.w r12, r12, #8 -; CHECK-NEXT: add.w lr, lr, r12, lsr #3 -; CHECK-NEXT: dlstp.16 lr, lr +; CHECK-NEXT: dlstp.16 lr, r3 ; CHECK-NEXT: .LBB6_1: @ %vector.body ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 ; CHECK-NEXT: vldrh.u16 q0, [r1]