[ARM][LowOverheadLoops] Inspect non-MVE instructions that define or use VPR

ARMInstrVFP.td: the `let` blocks wrapping the VPR system-register
definitions shown below (VMRS_VPR, "System level VPR/P0 -> GPR", and
VSTR_VPR) no longer set D=MVEDomain or validForTailPredication=1.

ARMLowOverheadLoops.cpp:
  * isVectorPredicated() moves above the anonymous namespace, next to new
    helpers: isVectorPredicate() (MI has a def operand for ARM::VPR),
    hasVPRUse() (MI has a use operand for ARM::VPR), isDomainMVE() (the
    TSFlags domain is ARMII::DomainMVE) and shouldInspect() (true if any
    of those three holds).
  * The scan over the loop header's instructions and the per-instruction
    validation, which previously skipped everything outside
    ARMII::DomainMVE, now skip only instructions for which shouldInspect()
    is false.
  * The ValidForTailPredication test now applies only to MVE-domain
    instructions, and its "Can't tail predicate" debug message is printed
    only when the instruction is not a VPR use (the function still
    returns IsUse).

vctp-in-vpt.mir: the expected output switches from MVE_DLSTP_32/MVE_LETP
to t2DLS/t2LEUpdate. It now includes an explicit trip-count computation
into $lr and keeps an MVE_VCTP32 inside the VPST block. In other words,
this loop is no longer turned into a tail-predicated loop.

NOTE(review): in this copy the patch lines have been joined onto a few
long lines. The context line "using InstSet = SmallPtrSetImpl;" also
appears to have lost its template argument list. The patch will
therefore not apply as-is; regenerate it from the original commit before
use.

diff --git a/llvm/lib/Target/ARM/ARMInstrVFP.td b/llvm/lib/Target/ARM/ARMInstrVFP.td --- a/llvm/lib/Target/ARM/ARMInstrVFP.td +++ b/llvm/lib/Target/ARM/ARMInstrVFP.td @@ -2500,8 +2500,7 @@ "vmrs", "\t$Rt, fpcxts", []>; } - let Predicates = [HasV8_1MMainline, HasMVEInt], - D=MVEDomain, validForTailPredication=1 in { + let Predicates = [HasV8_1MMainline, HasMVEInt] in { // System level VPR/P0 -> GPR let Uses = [VPR] in def VMRS_VPR : MovFromVFP<0b1100 /* vpr */, (outs GPR:$Rt), (ins), @@ -2861,8 +2860,7 @@ } } -let Predicates = [HasV8_1MMainline, HasMVEInt], - D=MVEDomain, validForTailPredication=1 in { +let Predicates = [HasV8_1MMainline, HasMVEInt] in { let Uses = [VPR] in { defm VSTR_VPR : vfp_vstrldr_sysreg<0b0,0b1100, "vpr">; } diff --git a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp --- a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp +++ b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp @@ -73,6 +73,29 @@ #define DEBUG_TYPE "arm-low-overhead-loops" #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass" +static bool isVectorPredicated(MachineInstr *MI) { + int PIdx = llvm::findFirstVPTPredOperandIdx(*MI); + return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR; +} + +static bool isVectorPredicate(MachineInstr *MI) { + return MI->findRegisterDefOperandIdx(ARM::VPR) != -1; +} + +static bool hasVPRUse(MachineInstr *MI) { + return MI->findRegisterUseOperandIdx(ARM::VPR) != -1; +} + +static bool isDomainMVE(MachineInstr *MI) { + uint64_t Domain = MI->getDesc().TSFlags & ARMII::DomainMask; + return Domain == ARMII::DomainMVE; +} + +static bool shouldInspect(MachineInstr &MI) { + return isDomainMVE(&MI) || isVectorPredicate(&MI) || + hasVPRUse(&MI); +} + namespace { using InstSet = SmallPtrSetImpl; @@ -563,11 +586,6 @@ return true; } -static bool isVectorPredicated(MachineInstr *MI) { - int PIdx = llvm::findFirstVPTPredOperandIdx(*MI); - return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR; 
-} - static bool isRegInClass(const MachineOperand &MO, const TargetRegisterClass *Class) { return MO.isReg() && MO.getReg() && Class->contains(MO.getReg()); @@ -703,9 +721,7 @@ MachineBasicBlock *Header = ML.getHeader(); for (auto &MI : *Header) { - const MCInstrDesc &MCID = MI.getDesc(); - uint64_t Flags = MCID.TSFlags; - if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE) + if (!shouldInspect(MI)) continue; if (isVCTP(&MI) || isVPTOpcode(MI.getOpcode())) @@ -854,9 +870,7 @@ if (CannotTailPredicate) return false; - const MCInstrDesc &MCID = MI->getDesc(); - uint64_t Flags = MCID.TSFlags; - if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE) + if (!shouldInspect(*MI)) return true; if (MI->getOpcode() == ARM::MVE_VPSEL || @@ -901,6 +915,9 @@ return true; } + // Inspect uses first so that any instructions that alter the VPR don't + // alter the predicate upon themselves. + const MCInstrDesc &MCID = MI->getDesc(); bool IsUse = false; bool IsDef = false; for (int i = MI->getNumOperands() - 1; i >= 0; --i) { @@ -941,8 +958,11 @@ // If we find an instruction that has been marked as not valid for tail // predication, only allow the instruction if it's contained within a valid // VPT block. 
- if ((Flags & ARMII::ValidForTailPredication) == 0) { - LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI); + bool RequiresExplicitPredication = + (MCID.TSFlags & ARMII::ValidForTailPredication) == 0; + if (isDomainMVE(MI) && RequiresExplicitPredication) { + LLVM_DEBUG(if (!IsUse) + dbgs() << "ARM Loops: Can't tail predicate: " << *MI); return IsUse; } diff --git a/llvm/test/CodeGen/Thumb2/LowOverheadLoops/vctp-in-vpt.mir b/llvm/test/CodeGen/Thumb2/LowOverheadLoops/vctp-in-vpt.mir --- a/llvm/test/CodeGen/Thumb2/LowOverheadLoops/vctp-in-vpt.mir +++ b/llvm/test/CodeGen/Thumb2/LowOverheadLoops/vctp-in-vpt.mir @@ -426,23 +426,30 @@ ; CHECK: bb.1.bb3: ; CHECK: successors: %bb.2(0x80000000) ; CHECK: liveins: $r0, $r1, $r2, $r3 + ; CHECK: renamable $r12 = t2ADDri renamable $r2, 3, 14 /* CC::al */, $noreg, $noreg + ; CHECK: renamable $lr = t2MOVi 1, 14 /* CC::al */, $noreg, $noreg + ; CHECK: renamable $r12 = t2BICri killed renamable $r12, 3, 14 /* CC::al */, $noreg, $noreg ; CHECK: $vpr = VMSR_P0 killed $r3, 14 /* CC::al */, $noreg + ; CHECK: renamable $r12 = t2SUBri killed renamable $r12, 4, 14 /* CC::al */, $noreg, $noreg ; CHECK: VSTR_P0_off killed renamable $vpr, $sp, 0, 14 /* CC::al */, $noreg :: (store 4 into %stack.0) ; CHECK: $r3 = tMOVr $r0, 14 /* CC::al */, $noreg - ; CHECK: $lr = MVE_DLSTP_32 killed renamable $r2 + ; CHECK: renamable $lr = nuw nsw t2ADDrs killed renamable $lr, killed renamable $r12, 19, 14 /* CC::al */, $noreg, $noreg + ; CHECK: $lr = t2DLS killed renamable $lr ; CHECK: bb.2.bb9: ; CHECK: successors: %bb.2(0x7c000000), %bb.3(0x04000000) - ; CHECK: liveins: $lr, $r0, $r1, $r3 + ; CHECK: liveins: $lr, $r0, $r1, $r2, $r3 ; CHECK: renamable $vpr = VLDR_P0_off $sp, 0, 14 /* CC::al */, $noreg :: (load 4 from %stack.0) - ; CHECK: MVE_VPST 4, implicit $vpr + ; CHECK: MVE_VPST 2, implicit $vpr + ; CHECK: renamable $vpr = MVE_VCTP32 renamable $r2, 1, killed renamable $vpr ; CHECK: renamable $r1, renamable $q0 = MVE_VLDRWU32_post killed 
renamable $r1, 16, 1, renamable $vpr ; CHECK: renamable $r3, renamable $q1 = MVE_VLDRWU32_post killed renamable $r3, 16, 1, killed renamable $vpr ; CHECK: $vpr = VMSR_P0 $r3, 14 /* CC::al */, $noreg + ; CHECK: renamable $r2, dead $cpsr = tSUBi8 killed renamable $r2, 4, 14 /* CC::al */, $noreg ; CHECK: renamable $q0 = nsw MVE_VMULi32 killed renamable $q1, killed renamable $q0, 0, $noreg, undef renamable $q0 ; CHECK: MVE_VPST 8, implicit $vpr ; CHECK: MVE_VSTRWU32 killed renamable $q0, killed renamable $r0, 0, 1, killed renamable $vpr ; CHECK: $r0 = tMOVr $r3, 14 /* CC::al */, $noreg - ; CHECK: $lr = MVE_LETP killed renamable $lr, %bb.2 + ; CHECK: $lr = t2LEUpdate killed renamable $lr, %bb.2 ; CHECK: bb.3.bb27: ; CHECK: $sp = tADDspi $sp, 1, 14 /* CC::al */, $noreg ; CHECK: tPOP_RET 14 /* CC::al */, $noreg, def $r7, def $pc