Index: llvm/include/llvm/CodeGen/ReachingDefAnalysis.h =================================================================== --- llvm/include/llvm/CodeGen/ReachingDefAnalysis.h +++ llvm/include/llvm/CodeGen/ReachingDefAnalysis.h @@ -189,6 +189,7 @@ /// definition is found, recursively search the predecessor blocks for them. void getLiveOuts(MachineBasicBlock *MBB, int PhysReg, InstSet &Defs, BlockSet &VisitedBBs) const; + void getLiveOuts(MachineBasicBlock *MBB, int PhysReg, InstSet &Defs) const; /// For the given block, collect the instructions that use the live-in /// value of the provided register. Return whether the value is still Index: llvm/lib/CodeGen/ReachingDefAnalysis.cpp =================================================================== --- llvm/lib/CodeGen/ReachingDefAnalysis.cpp +++ llvm/lib/CodeGen/ReachingDefAnalysis.cpp @@ -389,6 +389,12 @@ } } +void ReachingDefAnalysis::getLiveOuts(MachineBasicBlock *MBB, int PhysReg, + InstSet &Defs) const { + SmallPtrSet VisitedBBs; + getLiveOuts(MBB, PhysReg, Defs, VisitedBBs); +} + void ReachingDefAnalysis::getLiveOuts(MachineBasicBlock *MBB, int PhysReg, InstSet &Defs, BlockSet &VisitedBBs) const { Index: llvm/lib/Target/ARM/ARMBaseInstrInfo.h =================================================================== --- llvm/lib/Target/ARM/ARMBaseInstrInfo.h +++ llvm/lib/Target/ARM/ARMBaseInstrInfo.h @@ -679,7 +679,6 @@ static inline bool isMovRegOpcode(int Opc) { return Opc == ARM::MOVr || Opc == ARM::tMOVr || Opc == ARM::t2MOVr; } - /// isValidCoprocessorNumber - decide whether an explicit coprocessor /// number is legal in generic instructions like CDP. The answer can /// vary with the subtarget. 
Index: llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp =================================================================== --- llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp +++ llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp @@ -201,12 +201,25 @@ PredicatedMI *getDivergent() const { return Divergent; } }; + struct Reduction { + MachineInstr *Init; + MachineInstr &Copy; + MachineInstr &Reduce; + MachineInstr &VPSEL; + + Reduction(MachineInstr *Init, MachineInstr *Mov, MachineInstr *Add, + MachineInstr *Sel) + : Init(Init), Copy(*Mov), Reduce(*Add), VPSEL(*Sel) { } + }; + struct LowOverheadLoop { MachineLoop &ML; + MachineBasicBlock *Preheader = nullptr; MachineLoopInfo &MLI; ReachingDefAnalysis &RDA; const TargetRegisterInfo &TRI; + const ARMBaseInstrInfo &TII; MachineFunction *MF = nullptr; MachineInstr *InsertPt = nullptr; MachineInstr *Start = nullptr; @@ -218,14 +231,20 @@ SetVector CurrentPredicate; SmallVector VPTBlocks; SmallPtrSet ToRemove; + SmallVector, 1> Reductions; SmallPtrSet BlockMasksToRecompute; bool Revert = false; bool CannotTailPredicate = false; LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI, - ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI) - : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI) { + ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI, + const ARMBaseInstrInfo &TII) + : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI), TII(TII) { MF = ML.getHeader()->getParent(); + if (auto *MBB = ML.getLoopPreheader()) + Preheader = MBB; + else if (auto *MBB = MLI.findLoopPreheader(&ML, true)) + Preheader = MBB; } // If this is an MVE instruction, check that we know how to use tail @@ -249,9 +268,13 @@ // of elements to the loop start instruction. bool ValidateTailPredicate(MachineInstr *StartInsertPt); + // See whether the live-out instructions are a reduction that we can fixup + // later. 
+ bool FindValidReduction(InstSet &LiveMIs, InstSet &LiveOutUsers); + // Check that any values available outside of the loop will be the same // after tail predication conversion. - bool ValidateLiveOuts() const; + bool ValidateLiveOuts(); // Is it safe to define LR with DLS/WLS? // LR can be defined if it is the operand to start, because it's the same @@ -341,6 +364,8 @@ void ConvertVPTBlocks(LowOverheadLoop &LoLoop); + void FixupReductions(LowOverheadLoop &LoLoop) const; + MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop); void Expand(LowOverheadLoop &LoLoop); @@ -481,7 +506,7 @@ }; // First, find the block that looks like the preheader. - MachineBasicBlock *MBB = MLI.findLoopPreheader(&ML, true); + MachineBasicBlock *MBB = Preheader; if (!MBB) { LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find preheader.\n"); return false; @@ -626,7 +651,85 @@ return true; } -bool LowOverheadLoop::ValidateLiveOuts() const { +bool +LowOverheadLoop::FindValidReduction(InstSet &LiveMIs, InstSet &LiveOutUsers) { + // Also check for reductions where the operation needs to be merging values + // from the last and previous loop iterations. This means an instruction + // producing a value and a vmov storing the value calculated in the previous + // iteration. So we can have two live-out regs, one produced by a vmov and + // both being consumed by a vpsel. + LLVM_DEBUG(dbgs() << "ARM Loops: Looking for reduction live-outs:\n"; + for (auto *MI : LiveMIs) + dbgs() << " - " << *MI); + + // Expect a vmov, a vadd and a single vpsel user. 
+ if (LiveMIs.size() != 2 || LiveOutUsers.size() != 1) + return false; + + MachineInstr *VPSEL = *LiveOutUsers.begin(); + if (VPSEL->getOpcode() != ARM::MVE_VPSEL) + return false; + + unsigned VPRIdx = llvm::findFirstVPTPredOperandIdx(*VPSEL) + 1; + MachineInstr *Pred = RDA.getMIOperand(VPSEL, VPRIdx); + if (!Pred || Pred != VCTP) { + LLVM_DEBUG(dbgs() << "ARM Loops: Not using equivalent predicate.\n"); + return false; + } + + MachineInstr *Reduce = RDA.getMIOperand(VPSEL, 1); + if (!Reduce) + return false; + + // TODO: Support more operations than VADD. + if (Reduce->getOpcode() != ARM::MVE_VADDi32) + return false; + + // Check that the VORR is actually a VMOV. + MachineInstr *Copy = RDA.getMIOperand(VPSEL, 2); + if (!Copy || Copy->getOpcode() != ARM::MVE_VORR || + !Copy->getOperand(1).isReg() || !Copy->getOperand(2).isReg() || + Copy->getOperand(1).getReg() != Copy->getOperand(2).getReg()) + return false; + + assert((LiveMIs.count(Reduce) && LiveMIs.count(Copy)) && + "Expected live outs to be consumed by vpsel"); + + assert((Reduce->getOperand(0).getReg() == Reduce->getOperand(1).getReg() || + Reduce->getOperand(0).getReg() == Reduce->getOperand(2).getReg()) && + "Expected VADD to be overwriting one of its operands"); + + // Check that the vadd and vmov are only used by each other and the vpsel. + SmallPtrSet CopyUsers; + RDA.getGlobalUses(Copy, Copy->getOperand(0).getReg(), CopyUsers); + if (CopyUsers.size() > 2 || !CopyUsers.count(Reduce)) + return false; + + SmallPtrSet ReduceUsers; + RDA.getGlobalUses(Reduce, Reduce->getOperand(0).getReg(), ReduceUsers); + if (ReduceUsers.size() > 2 || !ReduceUsers.count(Copy)) + return false; + + // Then find whether there's an instruction initialising the register that + // is storing the reduction. + if (!Preheader) + return false; + + SmallPtrSet Incoming; + RDA.getLiveOuts(Preheader, Copy->getOperand(1).getReg(), Incoming); + if (Incoming.size() > 1) + return false; + + MachineInstr *Init = Incoming.empty() ? 
nullptr : *Incoming.begin(); + LLVM_DEBUG(dbgs() << "ARM Loops: Found a reduction:\n" + << " - " << *Copy + << " - " << *Reduce + << " - " << *VPSEL); + Reductions.push_back(std::make_unique(Init, Copy, Reduce, VPSEL)); + return true; +} + +bool LowOverheadLoop::ValidateLiveOuts() { // We want to find out if the tail-predicated version of this loop will // produce the same values as the loop in its original form. For this to // be true, the newly inserted implicit predication must not change the @@ -652,9 +755,9 @@ SetVector FalseLanesUnknown; SmallPtrSet FalseLanesZero; SmallPtrSet Predicated; - MachineBasicBlock *MBB = ML.getHeader(); + MachineBasicBlock *Header = ML.getHeader(); - for (auto &MI : *MBB) { + for (auto &MI : *Header) { const MCInstrDesc &MCID = MI.getDesc(); uint64_t Flags = MCID.TSFlags; if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE) @@ -702,6 +805,7 @@ // stored and then we can work towards the leaves, hopefully adding more // instructions to Predicated. Successfully terminating the loop means that // all the unknown values have to found to be masked by predicated user(s). + SmallPtrSet NonPredicated; for (auto *MI : reverse(FalseLanesUnknown)) { for (auto &MO : MI->operands()) { if (!isRegInClass(MO, QPRs) || !MO.isDef()) @@ -709,39 +813,45 @@ if (!HasPredicatedUsers(MI, MO, Predicated)) { LLVM_DEBUG(dbgs() << "ARM Loops: Found an unknown def of : " << TRI.getRegAsmName(MO.getReg()) << " at " << *MI); - return false; + NonPredicated.insert(MI); + continue; } } // Any unknown false lanes have been masked away by the user(s). Predicated.insert(MI); } - // Collect Q-regs that are live in the exit blocks. We don't collect scalars - // because they won't be affected by lane predication. 
- SmallSet LiveOuts; + SmallPtrSet LiveOutMIs; + SmallPtrSet LiveOutUsers; SmallVector ExitBlocks; ML.getExitBlocks(ExitBlocks); - for (auto *MBB : ExitBlocks) - for (const MachineBasicBlock::RegisterMaskPair &RegMask : MBB->liveins()) - if (QPRs->contains(RegMask.PhysReg)) - LiveOuts.insert(RegMask.PhysReg); - - // Collect the instructions in the loop body that define the live-out values. - SmallPtrSet LiveMIs; assert(ML.getNumBlocks() == 1 && "Expected single block loop!"); - for (auto Reg : LiveOuts) - if (auto *MI = RDA.getLocalLiveOutMIDef(MBB, Reg)) - LiveMIs.insert(MI); + assert(ExitBlocks.size() == 1 && "Expected a single exit block"); + MachineBasicBlock *ExitBB = ExitBlocks.front(); + for (const MachineBasicBlock::RegisterMaskPair &RegMask : ExitBB->liveins()) { + // Check Q-regs that are live in the exit blocks. We don't collect scalars + // because they won't be affected by lane predication. + if (QPRs->contains(RegMask.PhysReg)) { + if (auto *MI = RDA.getLocalLiveOutMIDef(Header, RegMask.PhysReg)) + LiveOutMIs.insert(MI); + RDA.getLiveInUses(ExitBB, RegMask.PhysReg, LiveOutUsers); + } + } + + // If we have any non-predicated live-outs, they need to be part of a + // reduction that we can fixup later. + if (!NonPredicated.empty() && + !FindValidReduction(NonPredicated, LiveOutUsers)) + return false; - LLVM_DEBUG(dbgs() << "ARM Loops: Found loop live-outs:\n"; - for (auto *MI : LiveMIs) - dbgs() << " - " << *MI); // We've already validated that any VPT predication within the loop will be // equivalent when we perform the predication transformation; so we know that // any VPT predicated instruction is predicated upon VCTP. Any live-out - // instruction needs to be predicated, so check this here. - for (auto *MI : LiveMIs) - if (!isVectorPredicated(MI)) + // instruction needs to be predicated, so check this here. The instructions + // in NonPredicated have been found to be part of a reduction whose legality + // we can ensure. 
+ for (auto *MI : LiveOutMIs) + if (!isVectorPredicated(MI) && !NonPredicated.count(MI)) return false; return true; @@ -949,14 +1059,12 @@ return nullptr; }; - LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI); + LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI, *TII); // Search the preheader for the start intrinsic. // FIXME: I don't see why we shouldn't be supporting multiple predecessors // with potentially multiple set.loop.iterations, so we need to enable this. - if (auto *Preheader = ML->getLoopPreheader()) - LoLoop.Start = SearchForStart(Preheader); - else if (auto *Preheader = MLI->findLoopPreheader(ML, true)) - LoLoop.Start = SearchForStart(Preheader); + if (LoLoop.Preheader) + LoLoop.Start = SearchForStart(LoLoop.Preheader); else return false; @@ -1210,6 +1318,61 @@ return &*MIB; } +void ARMLowOverheadLoops::FixupReductions(LowOverheadLoop &LoLoop) const { + LLVM_DEBUG(dbgs() << "ARM Loops: Fixing up reduction(s).\n"); + auto BuildMov = [this](MachineInstr &InsertPt, Register To, Register From) { + MachineBasicBlock *MBB = InsertPt.getParent(); + MachineInstrBuilder MIB = + BuildMI(*MBB, &InsertPt, InsertPt.getDebugLoc(), TII->get(ARM::MVE_VORR)); + MIB.addDef(To); + MIB.addReg(From); + MIB.addReg(From); + MIB.addImm(0); + MIB.addReg(0); + MIB.addReg(To); + LLVM_DEBUG(dbgs() << "ARM Loops: Inserted VMOV: " << *MIB); + }; + + for (auto &Reduction : LoLoop.Reductions) { + MachineInstr &Copy = Reduction->Copy; + MachineInstr &Reduce = Reduction->Reduce; + Register DestReg = Copy.getOperand(0).getReg(); + + // Change the initialiser if present + if (Reduction->Init) { + MachineInstr *Init = Reduction->Init; + + for (unsigned i = 0; i < Init->getNumOperands(); ++i) { + MachineOperand &MO = Init->getOperand(i); + if (MO.isReg() && MO.isUse() && MO.isTied() && + Init->findTiedOperandIdx(i) == 0) + Init->getOperand(i).setReg(DestReg); + } + Init->getOperand(0).setReg(DestReg); + LLVM_DEBUG(dbgs() << "ARM Loops: Changed init regs: " << *Init); + } else + 
BuildMov(LoLoop.Preheader->instr_back(), DestReg, Copy.getOperand(1).getReg()); + + // Change the reducing op to write to the register that is used to copy + // its value on the next iteration. Also update the tied-def operand. + Reduce.getOperand(0).setReg(DestReg); + Reduce.getOperand(5).setReg(DestReg); + LLVM_DEBUG(dbgs() << "ARM Loops: Changed reduction regs: " << Reduce); + + // Instead of a vpsel, just copy the register into the necessary one. + MachineInstr &VPSEL = Reduction->VPSEL; + if (VPSEL.getOperand(0).getReg() != DestReg) + BuildMov(VPSEL, VPSEL.getOperand(0).getReg(), DestReg); + + // Remove the unnecessary instructions. + LLVM_DEBUG(dbgs() << "ARM Loops: Removing:\n" + << " - " << Copy + << " - " << VPSEL << "\n"); + Copy.eraseFromParent(); + VPSEL.eraseFromParent(); + } +} + void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) { auto RemovePredicate = [](MachineInstr *MI) { LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI); @@ -1363,8 +1526,10 @@ RemoveDeadBranch(LoLoop.Start); LoLoop.End = ExpandLoopEnd(LoLoop); RemoveDeadBranch(LoLoop.End); - if (LoLoop.IsTailPredicationLegal()) + if (LoLoop.IsTailPredicationLegal()) { ConvertVPTBlocks(LoLoop); + FixupReductions(LoLoop); + } for (auto *I : LoLoop.ToRemove) { LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I); I->eraseFromParent(); Index: llvm/test/CodeGen/Thumb2/LowOverheadLoops/vector-arith-codegen.ll =================================================================== --- llvm/test/CodeGen/Thumb2/LowOverheadLoops/vector-arith-codegen.ll +++ llvm/test/CodeGen/Thumb2/LowOverheadLoops/vector-arith-codegen.ll @@ -9,28 +9,19 @@ ; CHECK-NEXT: moveq r0, #0 ; CHECK-NEXT: bxeq lr ; CHECK-NEXT: push {r7, lr} -; CHECK-NEXT: adds r3, r2, #3 -; CHECK-NEXT: vmov.i32 q0, #0x0 -; CHECK-NEXT: bic r3, r3, #3 -; CHECK-NEXT: sub.w r12, r3, #4 -; CHECK-NEXT: movs r3, #1 -; CHECK-NEXT: add.w lr, r3, r12, lsr #2 +; CHECK-NEXT: vmov.i32 q1, #0x0 ; CHECK-NEXT: movs r3, #0 -; CHECK-NEXT: 
dls lr, lr +; CHECK-NEXT: dlstp.32 lr, r2 ; CHECK-NEXT: .LBB0_1: @ %vector.body ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 -; CHECK-NEXT: vctp.32 r2 -; CHECK-NEXT: vmov q1, q0 -; CHECK-NEXT: vpstt -; CHECK-NEXT: vldrwt.u32 q0, [r0], #16 -; CHECK-NEXT: vldrwt.u32 q2, [r1], #16 +; CHECK-NEXT: vldrw.u32 q0, [r0], #16 +; CHECK-NEXT: vldrw.u32 q2, [r1], #16 ; CHECK-NEXT: adds r3, #4 ; CHECK-NEXT: vmul.i32 q0, q2, q0 -; CHECK-NEXT: subs r2, #4 -; CHECK-NEXT: vadd.i32 q0, q0, q1 -; CHECK-NEXT: le lr, .LBB0_1 +; CHECK-NEXT: vadd.i32 q1, q0, q1 +; CHECK-NEXT: letp lr, .LBB0_1 ; CHECK-NEXT: @ %bb.2: @ %middle.block -; CHECK-NEXT: vpsel q0, q0, q1 +; CHECK-NEXT: vmov q0, q1 ; CHECK-NEXT: vaddv.u32 r0, q0 ; CHECK-NEXT: pop {r7, pc} entry: @@ -85,26 +76,17 @@ ; CHECK-NEXT: moveq r0, #0 ; CHECK-NEXT: bxeq lr ; CHECK-NEXT: push {r7, lr} -; CHECK-NEXT: adds r1, r2, #3 -; CHECK-NEXT: movs r3, #1 -; CHECK-NEXT: bic r1, r1, #3 -; CHECK-NEXT: vmov.i32 q0, #0x0 -; CHECK-NEXT: subs r1, #4 -; CHECK-NEXT: add.w lr, r3, r1, lsr #2 +; CHECK-NEXT: vmov.i32 q1, #0x0 ; CHECK-NEXT: movs r1, #0 -; CHECK-NEXT: dls lr, lr +; CHECK-NEXT: dlstp.32 lr, r2 ; CHECK-NEXT: .LBB1_1: @ %vector.body ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 -; CHECK-NEXT: vctp.32 r2 -; CHECK-NEXT: vmov q1, q0 -; CHECK-NEXT: vpst -; CHECK-NEXT: vldrwt.u32 q0, [r0], #16 +; CHECK-NEXT: vldrw.u32 q0, [r0], #16 ; CHECK-NEXT: adds r1, #4 -; CHECK-NEXT: subs r2, #4 -; CHECK-NEXT: vadd.i32 q0, q0, q1 -; CHECK-NEXT: le lr, .LBB1_1 +; CHECK-NEXT: vadd.i32 q1, q0, q1 +; CHECK-NEXT: letp lr, .LBB1_1 ; CHECK-NEXT: @ %bb.2: @ %middle.block -; CHECK-NEXT: vpsel q0, q0, q1 +; CHECK-NEXT: vmov q0, q1 ; CHECK-NEXT: vaddv.u32 r0, q0 ; CHECK-NEXT: pop {r7, pc} entry: @@ -155,26 +137,17 @@ ; CHECK-NEXT: moveq r0, #0 ; CHECK-NEXT: bxeq lr ; CHECK-NEXT: push {r7, lr} -; CHECK-NEXT: adds r1, r2, #3 -; CHECK-NEXT: movs r3, #1 -; CHECK-NEXT: bic r1, r1, #3 -; CHECK-NEXT: vmov.i32 q0, #0x0 -; CHECK-NEXT: subs r1, #4 -; CHECK-NEXT: 
add.w lr, r3, r1, lsr #2 +; CHECK-NEXT: vmov.i32 q1, #0x0 ; CHECK-NEXT: movs r1, #0 -; CHECK-NEXT: dls lr, lr +; CHECK-NEXT: dlstp.32 lr, r2 ; CHECK-NEXT: .LBB2_1: @ %vector.body ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 -; CHECK-NEXT: vctp.32 r2 -; CHECK-NEXT: vmov q1, q0 -; CHECK-NEXT: vpst -; CHECK-NEXT: vldrwt.u32 q0, [r0], #16 +; CHECK-NEXT: vldrw.u32 q0, [r0], #16 ; CHECK-NEXT: adds r1, #4 -; CHECK-NEXT: subs r2, #4 -; CHECK-NEXT: vadd.i32 q0, q0, q1 -; CHECK-NEXT: le lr, .LBB2_1 +; CHECK-NEXT: vadd.i32 q1, q0, q1 +; CHECK-NEXT: letp lr, .LBB2_1 ; CHECK-NEXT: @ %bb.2: @ %middle.block -; CHECK-NEXT: vpsel q0, q0, q1 +; CHECK-NEXT: vmov q0, q1 ; CHECK-NEXT: vaddv.u32 r0, q0 ; CHECK-NEXT: pop {r7, pc} entry: